]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Another speed gain for Q4_0 and Q4_1 on Metal (#2375)
authorKawrakow <redacted>
Tue, 25 Jul 2023 10:48:29 +0000 (13:48 +0300)
committerGitHub <redacted>
Tue, 25 Jul 2023 10:48:29 +0000 (13:48 +0300)
* Another speed gain for Q4_0 and Q4_1 on Metal

* Have N_DST, etc., be template parameters

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.metal

index 987376d560879b97af68f85b066b6e5e4fb5ad6d..696b33ce75cf4fa8d92850ef5e762c32190d6410 100644 (file)
@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
     }
 }
 
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }
 
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    return d * (acc[0] + acc[1]) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      giard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                     int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                     uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16];       // src1 vector cache
+    float sumf[nr]={0.f};
 
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
 
-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
+    device const float * yb = y + ix * QK4_0 + il;
 
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }
 
-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
         }
+
+        yb += QK4_0 * 16;
     }
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
     }
 }
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(