]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Faster Q3_K implementation on Metal (#2307)
authorKawrakow <redacted>
Fri, 21 Jul 2023 14:05:30 +0000 (17:05 +0300)
committerGitHub <redacted>
Fri, 21 Jul 2023 14:05:30 +0000 (17:05 +0300)
* Faster Q3_K on Metal

* Additional Q3_K speedup on Metal

* Q3_K for QK_K = 64

* Better Q3_K for QK_K = 64

21.6 ms/t -> 21.1 ms/t

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.m
ggml-metal.metal

index 135bda9fc17acfeff8ba77306584e91e4453ed9f..2810fa2a841c5008d3d68b6b647db517680ab715 100644 (file)
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
@@ -743,15 +743,18 @@ void ggml_metal_graph_compute(
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
+                                else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
                                 else if (src0t == GGML_TYPE_Q5_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                                }
-                                else if (src0t == GGML_TYPE_Q3_K) {
-                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
index 97f5c10ba772840007188ed753e962f7fa0d5e3e..5a9a6d842387d06bec1df96c3aa3cb9c2f7b04af 100644 (file)
@@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // broadcast, simd group number is ntg / 32
-    for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
@@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
     }
 }
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
@@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
         constant   int64_t & ne10,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
-#if QK_K == 256
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
+    float yl[16];
 
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tpitg.y;        // expecting 16
+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint8_t m = 1 << (4*ip + il);
+    const uint16_t m1 = 1 << (4*ip + il);
+    const uint16_t m2 = m1 << 8;
 
     const int shift = 2*il;
+    const uint16_t qm1 = 0x0003 << shift;
+    const uint16_t qm2 = 0x0300 << shift;
+    const int32_t v1 = 4 << shift;
+    const int32_t v2 = 1024 << shift;
 
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
-    //float sumf = 0;
-    float sumf1 = 0, sumf2 = 0;
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    const int step = sizeof(block_q3_K) * nb / 2;
 
-        const float d_all = (float)(x[i].d);
-
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * h = x[i].hmask + l0;
-        device const float   * y = yy + i * QK_K + y_offset;
+    device const float * y1 = yy + ix*QK_K + y_offset;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 2) {
 
-        float s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0];
+            yl[l+8] = y1[l+16];
         }
-        float d = d_all * s;
-        sumf1 += d * scales[0];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[0] - 32);
 
-        s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+
+            const float d_all = (float)dh[0];
+            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+            float s1 = 0, s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2];
+                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+            }
+            float d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[0];
+            sumf2[row] += d;
+
+            s1 = s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2+8];
+                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+            }
+            d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[1];
+            sumf2[row] += d;
+
+            q  += step;
+            h  += step;
+            a  += step;
+            dh += step;
+
         }
-        d = d_all * s;
-        sumf1 += d * scales[1];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[1] - 32);
+
+        y1 += 2 * QK_K;
 
     }
 
-    //sum[ith] = sumf;
-    sum[ith] = sumf1 - 32.f*sumf2;
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+        const float tot = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
+    }
+}
 #else
-    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+kernel void kernel_mul_mat_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    const int row = 2 * r0 + sgitg;
+
+    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int ix = tiisg/4;
+    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
     const int im = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
 
-    float sumf = 0;
+    float2 sum = {0.f, 0.f};
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int i = ix; i < nb; i += 8) {
 
         const float d_all = (float)(x[i].d);
 
-        device const uint8_t * q = x[i].qs + il;
-        device const uint8_t * h = x[i].hmask + in;
-        device const float   * y = yy + i * QK_K + il;
-
-        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
-        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
-
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hm = h[l] >> im;
-            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
-                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
-                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
-                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+        device const float    * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+        for (int l = 0; l < 4; l += 2) {
+            const uint16_t hm = h[l/2] >> im;
+            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4))
+                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
         }
 
     }
+    const float sumf = sum[0] + sum[1] * 1.f/256.f;
 
-    sum[ith] = sumf;
-
-#endif
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
 
 }
+#endif
 
 #if QK_K == 256
 kernel void kernel_mul_mat_q4_K_f32(
@@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32(
 
     for (int i = ix; i < nb; i += 8) {
 
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < 4; ++l) {
             yl[l+0] = y[l+ 0];
             yl[l+4] = y[l+16];