]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : update sum_rows kernel to support float4 (llama/19524)
authorGeorgi Gerganov <redacted>
Thu, 12 Feb 2026 09:35:28 +0000 (11:35 +0200)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index 517559d12a667dc720f8a693dcac8709c0eb7d3c..06f3d80459069cc65a0ce26da732da3418bd9a31 100644 (file)
@@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
+    int op_num = -1;
+
     switch (op->op) {
-        case GGML_OP_SUM_ROWS:
-            op_str = "sum_rows"; break;
-        case GGML_OP_MEAN:
-            op_str = "mean"; break;
+        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
+        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;
         default: GGML_ABORT("fatal error");
     };
 
-    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d", base, op_num);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     res.smem = 32*sizeof(float);
 
+    if (is_c4) {
+        res.smem *= 4;
+    }
+
+    res.c4  = is_c4;
+
     return res;
 }
 
index 952e1be076e01e05f42fa6b349f667976a925faf..383e0d6e93b5728e752dcae59898ecaf268a39eb 100644 (file)
@@ -82,6 +82,7 @@
 #define FC_COUNT_EQUAL                 1100
 #define FC_UNARY                       1200
 #define FC_BIN                         1300
+#define FC_SUM_ROWS                    1400
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_UNARY_NUM_SOFTPLUS    115
 #define OP_UNARY_NUM_EXPM1       116
 
+#define OP_SUM_ROWS_NUM_SUM_ROWS 10
+#define OP_SUM_ROWS_NUM_MEAN     11
 
 // kernel argument structs
 //
index 7db95d1c84d3a08e1df706b407e67d1aae7bb0e7..20880d9551e2043e31e7340ffb93635dc42e09ef 100644 (file)
@@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
@@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 
     auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
     int nth = 32; // SIMD width
 
-    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00);
+    nth = std::min(nth, (int) args.ne00);
 
     const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
index 0036ba90ec90712a9fb5e9bab4eae3ed4e65fb6e..6c349aa0c9259d4169a3c2236726b45d73446a6b 100644 (file)
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
     return x*y;
 }
 
+static inline float sum(float x) {
+    return x;
+}
+
+static inline float sum(float4 x) {
+    return x[0] + x[1] + x[2] + x[3];
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -1501,33 +1509,35 @@ kernel void kernel_op_sum_f32(
     }
 }
 
-template <bool norm>
-kernel void kernel_sum_rows(
+constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
+
+template <typename T0, typename T>
+kernel void kernel_sum_rows_impl(
         constant ggml_metal_kargs_sum_rows & args,
-        device const float * src0,
-        device       float * dst,
-        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        device const char * src0,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    int64_t i3 = tgpig.z;
-    int64_t i2 = tgpig.y;
-    int64_t i1 = tgpig.x;
+#define FC_OP  FC_sum_rows_op
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
-        return;
-    }
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
 
     if (sgitg == 0) {
-        shmem_f32[tiisg] = 0.0f;
+        shmem_t[tiisg] = 0.0f;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float sumf = 0;
+    T0 sumf = T0(0.0f);
 
     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
         sumf += src_row[i0];
@@ -1538,23 +1548,33 @@ kernel void kernel_sum_rows(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     if (tiisg == 0) {
-        shmem_f32[sgitg] = sumf;
+        shmem_t[sgitg] = sumf;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sumf = shmem_f32[tiisg];
+    sumf = shmem_t[tiisg];
     sumf = simd_sum(sumf);
 
     if (tpitg.x == 0) {
-        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+            if (is_same<float4, T0>::value) {
+                dst_row[0] = sum(sumf) / (4*args.ne00);
+            } else {
+                dst_row[0] = sum(sumf) / args.ne00;
+            }
+        } else {
+            dst_row[0] = sum(sumf);
+        }
     }
+
+#undef FC_OP
 }
 
-typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
 
-template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
-template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl<float,  float>;
+template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
 
 template<typename T>
 kernel void kernel_cumsum_blk(
@@ -2435,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
     const short K   = FC_solve_tri_k;
     const short NP  = PAD2(N, NW);
 
-    const int32_t ne02 = args.ne02;
-    const int32_t ne03 = args.ne03;
-
     const int32_t i03 = tgpig.z;
     const int32_t i02 = tgpig.y;
     const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -5949,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
     static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
     static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
 
-    const short T = PK + NSG*SH; // shared memory size per query in (half)
+  //const short T = PK + NSG*SH; // shared memory size per query in (half)
 
   //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                      0*PK); // holds the query data
     threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t
@@ -8537,7 +8554,9 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -8660,8 +8679,8 @@ kernel void kernel_mul_mm(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -8910,7 +8929,9 @@ kernel void kernel_mul_mm_id(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -9045,8 +9066,8 @@ kernel void kernel_mul_mm_id(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
index ed99c24516eb7f225cda976da7e1107f1e3325a5..222b93584170b7818c9174f4dbbd3bd8564f5bfb 100644 (file)
@@ -8132,24 +8132,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_sum());
-    test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3}));  // row-contiguous but non-contiguous
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
+    test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
-    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
-    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));