]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Faster Q4_K on Metal (#2290)
authorKawrakow <redacted>
Thu, 20 Jul 2023 12:18:43 +0000 (15:18 +0300)
committerGitHub <redacted>
Thu, 20 Jul 2023 12:18:43 +0000 (15:18 +0300)
Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index d80a380d7903e2cfd9dd728f0b265810cf6460b2..5e2a2110067bad024e616a0517b2225a9a815bfb 100644 (file)
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
@@ -739,7 +739,8 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                    src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q2_K ||
index ee56336acfa7b5ea0a3cb81ae75256051a14705d..a9d134d6e693c5d2f5d0412d6f0defcb62d59f86 100644 (file)
@@ -1452,6 +1452,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 
 }
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
@@ -1459,131 +1460,180 @@ kernel void kernel_mul_mat_q4_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    float sumf = 0;
-
-#if QK_K == 256
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
 
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int ib = ix; ib < nb; ib += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+        for (int row = 0; row < N_DST; row++) {
 
-            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
-            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
 
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
+        y4 += 4 * QK_K;
     }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
+}
 #else
-    uint16_t aux16[2];
-    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+kernel void kernel_mul_mat_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    const int il  = 4*tpitg.x;
+    const int ix = tiisg/4;  // 0...7
+    const int it = tiisg%4;  // 0...3
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[8];
+    float yh[8];
+    float sumf[N_DST]={0.f}, all_sum;
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i * QK_K + il;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        aux16[0] = a[0] & 0x0f0f;
-        aux16[1] = (a[0] >> 4) & 0x0f0f;
+    uint16_t sc16[4];
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
-                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+    for (int ib = ix; ib < nb; ib += 8) {
+
+        float2 sumy = {0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+            yh[i] = y4[i+32]; sumy[1] += yh[i];
         }
-    }
-#endif
 
-    sum[ith] = sumf;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = x[ib].d;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & 0x000f;
+            sc16[1] = sc[0] & 0x0f00;
+            sc16[2] = sc[0] & 0x00f0;
+            sc16[3] = sc[0] & 0xf000;
+
+            float2 acc1 = {0.f, 0.f};
+            float2 acc2 = {0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+            qs += step;
+            sc += step;
+            dh += step;
+        }
 
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
+        y4 += 8 * QK_K;
+    }
 
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
 }
+#endif
 
 kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,