]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : multi-simd softmax (#3710)
authorGeorgi Gerganov <redacted>
Wed, 1 Nov 2023 19:25:00 +0000 (21:25 +0200)
committerGitHub <redacted>
Wed, 1 Nov 2023 19:25:00 +0000 (21:25 +0200)
ggml-ci

ggml-metal.m
ggml-metal.metal

index bc881395a7aadce12936adfa980e0a2a6a0b8d70..1f034150788e266e6f6a824b725ca6975e3d1ee2 100644 (file)
@@ -1001,11 +1001,15 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SOFT_MAX:
                         {
-                            const int nth = MIN(32, ne00);
+                            int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
+                                do {
+                                    nth *= 2;
+                                } while (nth <= ne00 && nth <= 1024);
+                                nth /= 2;
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1013,8 +1017,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_DIAG_MASK_INF:
                         {
index f4b460564453c5b3cfb1e0b545fea0d5e2c6182e..f3152778ae48c34d472d2150761a0773b0cbf2ce 100644 (file)
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] += buf[tpitg + i];
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           buf[tpitg] += buf[tpitg + i];
+       }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-     }
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(