]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : more optimizations (#2959)
authorKawrakow <redacted>
Sun, 3 Sep 2023 08:06:22 +0000 (11:06 +0300)
committerGitHub <redacted>
Sun, 3 Sep 2023 08:06:22 +0000 (11:06 +0300)
* Very minor speedup via simd-group synchronization in f16 x f32

* Another very minor speedup on metal

* Quite significant PP speedup on metal

* Another attempt

* Minor

* Massive improvement for TG for fp16

* ~4-5% improvement for Q8_0 TG on metal

---------

Co-authored-by: Iwan Kawrakow <redacted>
Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.m
ggml-metal.metal

index 88e7e13569c0bb0e5ebe73c2ec438ccc94d7c1bd..d0d23442eab5d83b74768d0b74bef674861d5e06 100644 (file)
@@ -76,6 +76,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -219,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -284,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -868,7 +871,11 @@ void ggml_metal_graph_compute(
                                         {
                                             nth0 = 32;
                                             nth1 = 1;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            if (ne11 * ne12 < 4) {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                            } else {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            }
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
@@ -920,8 +927,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 2;
-                                            nth1 = 32;
+                                            nth0 = 4; //1;
+                                            nth1 = 8; //32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
@@ -969,9 +976,12 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
+                                else if (src0t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -985,8 +995,8 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    int64_t ny = (ne11 + 3)/4;
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
                         } break;
index 8cdf0b9d2ba0a39aa92c8fbdd02e05795839e529..3fa311b4027f93bfd277cddf84015ffd6caa8a0f 100644 (file)
@@ -133,19 +133,24 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
+    //               the loop, and when that is done, buf[0] has the correct (synchronized) value
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float max = buf[0];
 
     // parallel sum
     buf[tpitg[0]] = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        buf[tpitg[0]] += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // whish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
     // reduce
@@ -157,17 +162,18 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    // broadcast - not needed, see above
+    //// broadcast
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float sum = buf[0];
 
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+        pdst[i00] /= sum;
     }
 }
 
@@ -214,25 +220,27 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //// broadcast
+    //if (tpitg == 0) {
+    //    sum[0] /= ne00;
+    //}
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
     const float mean  = sum[0];
 
-    // recenter
+    // recenter and VARIANCE
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
+
+    //// VARIANCE
+    //// parallel sum
+    //sum[tpitg] = 0.0f;
+    //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    //    sum[tpitg] += y[i00] * y[i00];
+    //}
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,11 +249,11 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //// broadcast
+    //if (tpitg == 0) {
+    //    sum[0] /= ne00;
+    //}
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
     const float variance = sum[0];
 
     const float scale = 1.0f/sqrt(variance + eps);
@@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const  void * src0,
         device const float * src1,
@@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += QK8_0 * 16;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
@@ -497,7 +507,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -515,11 +525,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -528,42 +535,102 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    uint ith = tpitg.x;
-    uint nth = tptg.x;
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
 
-    sum[ith] = 0.0f;
+}
 
-    for (int i = ith; i < ne00; i += nth) {
-        sum[ith] += (float) x[i] * (float) y[i];
-    }
+#define N_F16_F32 4
 
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = N_F16_F32*tgpig.y;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
     }
 
-    // Original implementation. Left behind commented out for now
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = tptg.x/2; i > 0; i /= 2) {
-    //    if (tpitg.x < i) {
-    //        sum[tpitg.x] += sum[tpitg.x + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-    //
-    //if (tpitg.x == 0) {
-    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_alibi_f32(
@@ -1262,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;