]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Faster Q2_K on Metal (#2297)
authorKawrakow <redacted>
Fri, 21 Jul 2023 07:44:40 +0000 (10:44 +0300)
committerGitHub <redacted>
Fri, 21 Jul 2023 07:44:40 +0000 (10:44 +0300)
* Faster Q2_K on Metal

* Deleting unnoticed and dangereous trailing white space

* Fixed bug in new metal Q2_K implementation

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 44d046877e85f5897552a4898f9bf130c0bc08b9..135bda9fc17acfeff8ba77306584e91e4453ed9f 100644 (file)
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
@@ -740,7 +740,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                                    src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
@@ -749,8 +749,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
-                                else if (src0t == GGML_TYPE_Q2_K ||
-                                         src0t == GGML_TYPE_Q3_K) {
+                                else if (src0t == GGML_TYPE_Q3_K) {
                                     [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
index f71e8f33b928e0be6b13f8d4c1134d344d83116a..97f5c10ba772840007188ed753e962f7fa0d5e3e 100644 (file)
@@ -1209,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    float sumf = 0;
+    const int step = sizeof(block_q2_K) * nb;
 
 #if QK_K == 256
-    const int tid = tpitg.y;    // 0...16
-    const int il  = tid/4;      // 0...3
-    const int ir  = tid%4;      // 0...3
-    const int ip  = il/2;       // 0 or 1
-    const int shift1 = 4*(il%2);// 0 or 4
-    const int shift2 = shift1+2;// 2 or 6
-    const int n   = 8;
-    const int is  = 4*il + (n*ir)/16;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
 
-    const int y_offset = 64*il + n*ir;
-    const int q_offset = 32*ip + n*ir;
+    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int ib = ix; ib < nb; ib += 4) {
 
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * scales = x[i].scales + is;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
 
-        uint8_t d1 = scales[0] & 0xF;
-        uint8_t d2 = scales[2] & 0xF;
-        uint8_t m1 = scales[0] >>  4;
-        uint8_t m2 = scales[2] >>  4;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
-        device const float   * y = yy + i*QK_K + y_offset;
+        for (int row = 0; row < N_DST; row++) {
 
-        float2 s = {0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
-            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
-            smin += y[l+ 0] * m1 + y[l+32] * m2;
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+        y4 += 4 * QK_K;
     }
 #else
-    const int il = 4 * tpitg.x;
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0...1
 
-    uint32_t aux[2];
-    thread const uint8_t * d = (thread const uint8_t *)aux;
-    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int ib = ix; ib < nb; ib += 16) {
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i*QK_K + il;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+        }
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = &x[ib].d;
 
-        device const uint32_t * a = (device const uint32_t *)x[i].scales;
-        aux[0] = a[0] & 0x0f0f0f0f;
-        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+        for (int row = 0; row < N_DST; row++) {
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
-                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
-                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
-                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
+
+        y4 += 16 * QK_K;
     }
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
     }
 }