]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : fix ACC op (llama/19427)
authorGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 07:54:03 +0000 (09:54 +0200)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-ops.cpp

index b4ca9c5dd6fb6a062cfe28239cedbc92ee9883ea..3db7f12629181ef1f63c6f03e49265aa608ae9b3 100644 (file)
@@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_ADD_ID:
-            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
+            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_REPEAT:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
index c04e9fc7ffa61d2e38ea4280315eee052da691f0..3d5db0b79f58525b5efb2f8be4c87da312e258a4 100644 (file)
@@ -620,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->type         == GGML_TYPE_F32);
 
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
     const size_t pnb1 = ((const int32_t *) op->op_params)[0];
     const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -671,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     }
 
     ggml_metal_kargs_bin args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
         /*.nb00 =*/ nb00,
         /*.nb01 =*/ pnb1,
         /*.nb02 =*/ pnb2,
@@ -687,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.nb11 =*/ nb11,
         /*.nb12 =*/ nb12,
         /*.nb13 =*/ nb13,
-        /*.ne0  =*/ ne0,
-        /*.ne1  =*/ ne1,
-        /*.ne2  =*/ ne2,
-        /*.ne3  =*/ ne3,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
         /*.nb0  =*/ nb0,
         /*.nb1  =*/ pnb1,
         /*.nb2  =*/ pnb2,
@@ -707,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 
-    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+    const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    int nth = 1;
+
+    while (2*nth < args.ne0 && nth < nth_max) {
+        nth *= 2;
+    }
 
     ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);