]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal: minor q4 optimization and reduce code size (#2248)
authorShouzheng Liu <redacted>
Thu, 20 Jul 2023 10:32:22 +0000 (06:32 -0400)
committerGitHub <redacted>
Thu, 20 Jul 2023 10:32:22 +0000 (13:32 +0300)
* metal: use uint16_t instead of uint8_t.

Apple GPU doesn't like uint8_t. For every operation on uint8_t
the gpu need to copy the uint8_t to an empty 16 bit register, then
it can issue other instructions.

For the matrix-vector multiplication kernel only, we observed a
340~350 GB/s memory read speed on M1 Max after this commit, which is
very close to the reported hardware limit.

* metal: update rms_norm kernel

This commit double the speed of rms_norm operations by using 512 threads
per threadgroup, combining with SIMD primitives to minimize the need for
thread group barriers.

* metal: use template to reduce size

Revert modifications on block_q4_0 and block_q4_1.

ggml-metal.m
ggml-metal.metal

index ee205bcdf773ca1666016ac63aae16a7aafdfb96..d80a380d7903e2cfd9dd728f0b265810cf6460b2 100644 (file)
@@ -792,7 +792,7 @@ void ggml_metal_graph_compute(
 
                             const float eps = 1e-6f;
 
-                            const int nth = 256;
+                            const int nth = 512;
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -800,7 +800,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
index 9f9a4fbd74446e2c203fe2f8c65653ee6f5360f7..ee56336acfa7b5ea0a3cb81ae75256051a14705d 100644 (file)
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
         threadgroup float  * sum [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+    float4 sumf=0;
+    float all_sum=0;
 
     // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00] * x[i00];
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (tiisg == 0) {
+        sum[sgitg] = all_sum;
     }
 
-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast, simd group number is ntg / 32
+    for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           sum[tpitg] += sum[tpitg + i];
+       }
     }
-
-    // broadcast
     if (tpitg == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
         sum[0] /= ne00;
     }
 
@@ -359,104 +366,102 @@ kernel void kernel_rms_norm(
     const float mean  = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float * y_scalar = (device float *) y;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
+    if (tpitg == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+    }
+}
+
+// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+}
+
+// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-kernel void kernel_mul_mat_q4_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+template<typename block_q_type>
+void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                    int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                    uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-    block_q4_0 qb_curr, qb_next;
     float4 y_curr[8];       // src1 vector cache
     float sumf[N_DST]={0.f}, all_sum;
     thread float * yl=(thread float *)y_curr;
 
-    // bootstrap
-    qb_curr = x[tiisg];
     // each thread in a SIMD group deals with 1 block.
     for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
         float sumy = 0;
         for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
 
         for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            sumf[row] += d * acc;
-            qb_curr = qb_next;
+            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
         }
     }
 
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
+    // from now loads two rows every time and 16 blocks per row
+    int ir = tiisg / (N_SIMDWIDTH / 2);
+    int ib = tiisg % (N_SIMDWIDTH / 2);
+    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
+        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
         float sumy = 0;
         for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
 
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+        for (int row = 0; row < N_DST; row+=2) {
+            if (nb_start + ib < nb) {
+                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
             }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc;
-            }
-            qb_curr = qb_next;
+        }
+    }
 
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mat_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -467,80 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int nb = ne00/QK4_0;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
-    block_q4_1 qb_curr, qb_next;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
-
-    // bootstrap
-    qb_curr = x[tiisg];
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            sumf[row] += d * acc + m * sumy;
-            qb_curr = qb_next;
-        }
-    }
-
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc + m * sumy;
-            }
-            qb_curr = qb_next;
+    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+}
 
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    }
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(