#endif // FP16_AVAILABLE
}
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.x;
+ const int col = threadIdx.x;
+
+ float sum = 0.0f;
+ for (int i = col; i < ncols; i += blockDim.x) {
+ sum += x[row * ncols + i];
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ if (col != 0) {
+ return;
+ }
+
+ dst[row] = norm ? sum / ncols : sum;
+}
+
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
+ case GGML_OP_MEAN:
+ ggml_cuda_op_mean(ctx, dst);
+ break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
--- /dev/null
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+ reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#include "sumrows.cuh"
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
- const int row = blockIdx.x;
- const int col = threadIdx.x;
-
- float sum = 0.0f;
- for (int i = col; i < ncols; i += blockDim.x) {
- sum += x[row * ncols + i];
- }
-
- sum = warp_reduce_sum(sum);
-
- if (col == 0) {
- dst[row] = sum;
- }
-}
-
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
- k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+ reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
- sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
return test_cases;
}