]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Faster Q5_K and Q6_K on Metal (#2294)
authorKawrakow <redacted>
Thu, 20 Jul 2023 15:19:45 +0000 (18:19 +0300)
committerGitHub <redacted>
Thu, 20 Jul 2023 15:19:45 +0000 (18:19 +0300)
* Faster Q6_K on Metal

* Faster Q5_K on Metal

* Another Q5_K speedup

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 5e2a2110067bad024e616a0517b2225a9a815bfb..44d046877e85f5897552a4898f9bf130c0bc08b9 100644 (file)
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
@@ -743,11 +743,14 @@ void ggml_metal_graph_compute(
                                     src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
+                                else if (src0t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                                 else if (src0t == GGML_TYPE_Q2_K ||
-                                         src0t == GGML_TYPE_Q3_K ||
-                                         src0t == GGML_TYPE_Q4_K ||
-                                         src0t == GGML_TYPE_Q5_K ||
-                                         src0t == GGML_TYPE_Q6_K) {
+                                         src0t == GGML_TYPE_Q3_K) {
                                     [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
index a9d134d6e693c5d2f5d0412d6f0defcb62d59f86..f71e8f33b928e0be6b13f8d4c1134d344d83116a 100644 (file)
@@ -1642,39 +1642,39 @@ kernel void kernel_mul_mat_q5_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    float sumf[2]={0.f};
 
-    float sumf = 0;
+    const int step = sizeof(block_q5_K) * nb;
 
 #if QK_K == 256
+#
+    float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int im  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
 
-    const int l0 = n*(2*ir + in);
+    const int l0 = n*ir;
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
@@ -1683,78 +1683,114 @@ kernel void kernel_mul_mat_q5_K_f32(
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    device const float * y1 = yy + ix*QK_K + y_offset;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const uint8_t * qh = (x + i)->qh + l0;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int i = ix; i < nb; i += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+        for (int row = 0; row < 2; ++row) {
+
+            device const uint8_t * q2 = q1 + 64;
+
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
-            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
-            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
-            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            q1 += step;
+            qh += step;
+            dh += step/2;
+            a  += step/2;
 
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+        y1 += 4 * QK_K;
 
     }
 #else
-    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
-    const int im  = il/8;         // 0, 0, 1, 1
-    const int in  = il%8;         // 0, 4, 0, 4
+    float yl[8], yh[8];
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
+    const int ix = tiisg%8;
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
 
-        const float d = (float)x[i].d;
+    device const float * y = yy + ix*QK_K + il;
+
+    for (int i = ix; i < nb; i += 8) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            yl[l+0] = y[l+ 0];
+            yl[l+4] = y[l+16];
+            yh[l+0] = y[l+32];
+            yh[l+4] = y[l+48];
+        }
+
+        device const half * dh = &x[i].d;
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
         device const int8_t  * s = x[i].scales;
-        device const float   * y = yy + i*QK_K + il;
 
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
-                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
-                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
-                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        for (int row = 0; row < 2; ++row) {
+
+            const float d = dh[0];
+
+            float2 acc = {0.f, 0.f};
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t hl = h[l] >> im;
+                acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+                        + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+                acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+                        + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+            }
+            sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+            q += step;
+            h += step;
+            s += step;
+            dh += step/2;
+
         }
+
+        y += 8 * QK_K;
     }
 #endif
-    sum[ith] = sumf;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
     }
 
 }
@@ -1766,10 +1802,9 @@ kernel void kernel_mul_mat_q6_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -1781,19 +1816,18 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int row = 2 * r0 + sgitg;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     float sumf = 0;
 
 #if QK_K == 256
-    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
-    const int iqs  = 16 * tpitg.y;
-    const int ip   = iqs / 128;         // 0 or 1
-    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
     const int n    = 4;
     const int l0   = n*il;
     const int is   = 8*ip + l0/16;
@@ -1802,9 +1836,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int i = ix; i < nb; i += 2) {
 
-        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
@@ -1814,19 +1849,21 @@ kernel void kernel_mul_mat_q6_K_f32(
 
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+
 #else
-    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+    const int ix  = tiisg/4;
+    const int il  = 4*(tiisg%4);
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int i = ix; i < nb; i += 8) {
         device const float * y = yy + i * QK_K + il;
         device const uint8_t * ql = x[i].ql + il;
         device const uint8_t * qh = x[i].qh + il;
@@ -1846,23 +1883,8 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
-
 }