]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : avoid divisions in bin kernel (#20426)
authorGeorgi Gerganov <redacted>
Thu, 12 Mar 2026 07:42:40 +0000 (09:42 +0200)
committerGitHub <redacted>
Thu, 12 Mar 2026 07:42:40 +0000 (09:42 +0200)
* metal : avoid modulus in bin kernel when not broadcasting

* metal : fix capture_started flag

ggml/src/ggml-metal/ggml-metal-context.m
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 855fd1adae8e674e0a4450fa817bfb673decb8bc..32d97cd5d0af187d98c557b57409dcc80b55aa4e 100644 (file)
@@ -554,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
 
         // enter here only when capturing in order to wait for all computation to finish
         // otherwise, we leave the graph to compute asynchronously
-        if (!use_capture && ctx->capture_started) {
+        if (use_capture && ctx->capture_started) {
             // wait for completion and check status of each command buffer
             // needed to detect if the device ran out-of-memory for example (#1881)
             {
@@ -606,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
 
             [ctx->capture_scope endScope];
             [[MTLCaptureManager sharedCaptureManager] stopCapture];
+
+            ctx->capture_started = false;
         }
     }
 
index 15ae2e517df75ac293524bac10a452c35c8f9f82..72ad876d5e48db3943b522c45c1d6416bc7ff58b 100644 (file)
@@ -1470,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
 
     const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
 
+    const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
     const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
 
     snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
-    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
+    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
@@ -1482,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
         ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
         ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
         ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);
+        ggml_metal_cv_set_bool (cv, is_cb,  FC_BIN + 3);
 
         res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
index 306dbcf36609d303bd7fb4652c62dfad148e54e3..c0bcad392b9440ba18f2d82c57865c0ed50f7f2d 100644 (file)
@@ -3180,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
 
     if (pipeline.cnt) {
-        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
-
-        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+        ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
     } else {
         const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
index 0b77d5349b861f03101c766d6f5490ca0ecc24a4..24a3092af22c65fac7e5b87606d2e40132dbc399 100644 (file)
@@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un
 constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
 constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
 constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
+constant bool  FC_bin_cb [[function_constant(FC_BIN + 3)]];
 
 template <typename T0, typename T1, typename T>
 kernel void kernel_bin_fuse_impl(
@@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl(
 #define FC_OP FC_bin_op
 #define FC_F  FC_bin_f
 #define FC_RB FC_bin_rb
+#define FC_CB FC_bin_cb
 
     if (FC_RB) {
         // row broadcast
-        const uint i0 = tgpig.x;
-        const uint i1 = i0%args.ne10;
+        const uint i0 = tgpig.y*args.ne00 + tgpig.x;
+        const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
 
         device const T0 * src0_row = (device const T0 *) (src0);
         device       T  * dst_row  = (device       T  *) (dst);
@@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl(
             device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
 
             for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-                const int i10 = i0%args.ne10;
+                const int i10 = FC_CB ? i0%args.ne10 : i0;
 
                 if (FC_OP == 0) {
                     dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
@@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl(
             }
 
             for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-                const int i10 = i0%args.ne10;
+                const int i10 = FC_CB ? i0%args.ne10 : i0;
 
                 T res = src0_ptr[i0];
 
@@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl(
 #undef FC_OP
 #undef FC_F
 #undef FC_RB
+#undef FC_CB
 }
 
 typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;