From: Sam/Samuel
Date: Wed, 15 Oct 2025 14:05:56 +0000 (+0900)
Subject: metal: optimise `GGML_OP_SUM` (llama/16559)
X-Git-Tag: upstream/1.8.3~461
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=d8a146b0f9a1af396e1812e3fc6859483752dab1;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

metal: optimise `GGML_OP_SUM` (llama/16559)

* optimise GGML_OP_SUM

* add non-contiguous tests by permuting the input

* change tests to require full contiguity of OP_SUM

* cuda : add check GGML_OP_SUM

---------

Co-authored-by: Georgi Gerganov
---

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a5e77672..75fd6db1 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3625,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
             // TODO: Support arbitrary column width
             return op->src[0]->ne[0] <= 1024;
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 553cf8f5..c3c83abe 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 784b7b77..4f9f6bda 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
 
+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
 
     return 1;
 }
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 6d39ddcc..496610b1 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
         constant ggml_metal_kargs_sum & args,
         device const float * src0,
         device       float * dst,
-        ushort tiitg[[thread_index_in_threadgroup]]) {
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
 
-    if (tiitg != 0) {
+    if (args.np == 0) {
         return;
     }
 
-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
     }
 
-    dst[0] = acc;
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }
 
 template
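
The new kernel replaces the old single-thread loop with a threadgroup-wide, two-stage reduction: each of the nth threads accumulates a strided partial sum, simd_sum collapses each simdgroup of 32 threads, lane 0 of every simdgroup writes its result into the nsg-float threadgroup buffer, and the first simdgroup reduces those partials and writes dst[0]. For reference, the scheme can be modelled on the host as in the sketch below; this is only a sequential simulation of the same arithmetic, and model_sum, nth and the other names are illustrative, not ggml or Metal API.

// Host-side C++ sketch (not part of the patch) of the two-stage reduction
// performed by the new kernel_op_sum_f32. Simdgroups are simulated
// sequentially; names are illustrative only.
#include <cstdio>
#include <vector>

static float model_sum(const std::vector<float> & src0, int nth) {
    const int nsg = (nth + 31) / 32;          // number of simdgroups
    std::vector<float> partial(nth, 0.0f);    // per-thread accumulators
    std::vector<float> shmem(nsg, 0.0f);      // models the threadgroup buffer

    // stage 0: each thread strides over the input (i0 = tpitg.x; i0 += ntg.x)
    for (int t = 0; t < nth; ++t) {
        for (size_t i0 = t; i0 < src0.size(); i0 += nth) {
            partial[t] += src0[i0];
        }
    }

    // stage 1: simd_sum within each simdgroup of 32 threads;
    // lane 0 writes the simdgroup result to shared memory
    for (int sg = 0; sg < nsg; ++sg) {
        for (int lane = 0; lane < 32; ++lane) {
            const int t = sg*32 + lane;
            if (t < nth) {
                shmem[sg] += partial[t];
            }
        }
    }

    // stage 2: the first simdgroup reduces the nsg partials and
    // thread 0 writes dst[0]
    float total = 0.0f;
    for (int sg = 0; sg < nsg; ++sg) {
        total += shmem[sg];
    }
    return total;
}

int main() {
    std::vector<float> src0(1000, 1.0f);
    printf("sum = %f\n", model_sum(src0, /*nth =*/ 256)); // expect 1000
    return 0;
}

This also shows why the host-side change in ggml-metal-ops.cpp reserves exactly nsg * sizeof(float) bytes of threadgroup memory and dispatches a single threadgroup of nth threads: one float per simdgroup is enough to hold the intermediate partial sums.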