]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
SYCL: Add GGML_OP_MEAN operator support (llama/16009)
authoryael-works <redacted>
Thu, 16 Oct 2025 04:21:28 +0000 (07:21 +0300)
committerGeorgi Gerganov <redacted>
Tue, 21 Oct 2025 15:14:33 +0000 (18:14 +0300)
* SYCL: Add GGML_OP_MEAN operator support

* SYCL: Fix formatting for GGML_OP_MEAN case

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Sigbjørn Skjæret <redacted>
src/ggml-sycl/ggml-sycl.cpp

index 45b8c216c94d2574af6fd8240b58476d9d476d0e..f3407a813d731de1635523b4ef76116844fcf2cd 100644 (file)
@@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
 }
 
+inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    const int64_t ncols = dst->src[0]->ne[0];
+    const int64_t nrows = ggml_nrows(dst->src[0]);
+
+    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
+
+    main_stream->parallel_for(
+        sycl::range<1>(nrows),
+        [=](sycl::id<1> row) {
+            dst_dd[row] /= ncols;
+        }
+    );
+}
+
+
 inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     ggml_sycl_op_sum_rows(ctx, dst);
 }
 
+static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_mean(ctx, dst);
+}
+
 static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_sycl_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_sycl_mean(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
@@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_POOL_2D: