]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal: optimise `GGML_OP_SUM` (llama/16559)
authorSam/Samuel <redacted>
Wed, 15 Oct 2025 14:05:56 +0000 (23:05 +0900)
committerGeorgi Gerganov <redacted>
Tue, 21 Oct 2025 15:14:33 +0000 (18:14 +0300)
* optimise GGML_OP_SUM

* add non-contiguous tests by permuting the input

* change tests to require full contiguity of OP_SUM

* cuda : add check GGML_OP_SUM

---------

Co-authored-by: Georgi Gerganov <redacted>
src/ggml-cuda/ggml-cuda.cu
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index a5e77672f6e95e2fdd0531a5ef39bfb86bb47951..75fd6db14c514566640472ef4627db0e036f51d7 100644 (file)
@@ -3625,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
             // TODO: Support arbitrary column width
             return op->src[0]->ne[0] <= 1024;
index 553cf8f5f39ac8a6bdc58262ee2973468ce29407..c3c83abe4e63e55fd214281d0196ce73a32972a9 100644 (file)
@@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
index 784b7b77851e6873f9c8c4c97c417ceb4fc0211b..4f9f6bda00a7901f21b206f308e51d5a9b7c6652 100644 (file)
@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
 
+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
 
     return 1;
 }
index 6d39ddcc634ef735e379a5710b410f2a17a25022..496610b154b6d7642be0e0d151c0cc4f7af872d0 100644 (file)
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
         constant ggml_metal_kargs_sum & args,
         device const float * src0,
         device       float * dst,
-        ushort  tiitg[[thread_index_in_threadgroup]]) {
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
 
-    if (tiitg != 0) {
+    if (args.np == 0) {
         return;
     }
 
-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
     }
 
-    dst[0] = acc;
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }
 
 template <bool norm>
index 14472dcf124fb0e692894f4561aff9d7419441ad..3a2876c808ffc2e57aa875189702e4159252f5c8 100644 (file)
@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
 struct test_sum : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> permute;
+    bool _use_permute;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
     }
 
     test_sum(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
+        if (_use_permute) {
+            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(a, "a_permuted");
+        }
+
         ggml_tensor * out = ggml_sum(ctx, a);
         ggml_set_name(out, "out");
 
@@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3}));  // row-contiguous but non-contiguous
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
@@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));