]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal: somewhat faster f16 x f32 matrix multiply kernel (#2951)
authorKawrakow <redacted>
Fri, 1 Sep 2023 08:15:57 +0000 (11:15 +0300)
committerGitHub <redacted>
Fri, 1 Sep 2023 08:15:57 +0000 (11:15 +0300)
* Somewhat faster f16 x f32 matrix multiply kernel

* Better use 32 thread groups for f16 x f32

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index e929c4b07cadd054119bdec9cd6ad0c75f26b631..8c3c64f53f00a7186f9cb12236c5b389a198c57f 100644 (file)
@@ -840,7 +840,7 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            nth0 = 64;
+                                            nth0 = 32;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                                         } break;
index 82e1a0c7aca06e124898dc1c8caa7064bafef261..02db5323ea0f27956cd6f756b273e8bc2e91df2b 100644 (file)
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    sum[ith] = 0.0f;
+
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
+
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
 }
 
 kernel void kernel_alibi_f32(