]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : batch rows copy in a single threadgroup (llama/14384)
authorGeorgi Gerganov <redacted>
Thu, 26 Jun 2025 12:50:15 +0000 (15:50 +0300)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 08:52:14 +0000 (11:52 +0300)
* metal : batch rows copy in a single threadgroup

ggml-ci

* metal : handle some edge cases when threadgroup size is not a power of 2

ggml-ci

src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal

index 19f4d59e59747452920f1e9d72908681afbde0a2..248fa378ef0f17b8c63be834476e580e42407632 100644 (file)
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 ggml_metal_kargs_rms_norm args = {
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 ggml_metal_kargs_l2_norm args = {
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
                     nth *= 2;
                 }
 
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
                 ggml_metal_kargs_norm args = {
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                // TODO: support
+                //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
+                const int32_t nk00 = ne00;
+
+                int nth = 32; // SIMD width
+
+                while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+                // when rows are small, we can batch them together in a single threadgroup
+                int nrptg = 1;
+
+                // TODO: relax this constraint in the future
+                if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
+                    if (nth > nk00) {
+                        nrptg = (nth + nk00 - 1)/nk00;
+                        nth   = nk00;
+
+                        if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                            nrptg--;
+                        }
+                    }
+                }
+
+                nth = MIN(nth, nk00);
+
                 ggml_metal_kargs_cpy args = {
-                    /*.ne00 =*/ ne00,
+                    /*.ne00 =*/ nk00,
                     /*.ne01 =*/ ne01,
                     /*.ne02 =*/ ne02,
                     /*.ne03 =*/ ne03,
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
             } break;
         case GGML_OP_SET:
             {
index 3da19879b4b364b993390cf748c60748e9f14964..f028276068ef4beb36c62d14672aa2275894a1fe 100644 (file)
@@ -4306,11 +4306,16 @@ kernel void kernel_cpy(
         device  const char * src0,
         device        char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
+        uint    tiitg[[thread_index_in_threadgroup]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
+        ushort3  tptg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
+
+    if (i01 >= args.ne01) {
+        return;
+    }
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
@@ -4321,7 +4326,7 @@ kernel void kernel_cpy(
 
     device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
         device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
     }