]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : Q3_K speedup (#2995)
authorKawrakow <redacted>
Fri, 8 Sep 2023 16:01:04 +0000 (18:01 +0200)
committerGitHub <redacted>
Fri, 8 Sep 2023 16:01:04 +0000 (19:01 +0300)
* Slightly faster Q3_K and Q5_K on metal

* Another Q3_K speedup on metal

Combined with previous commit, we are now +9.6% for TG.
PP is not affected as this happens via the matrix multiplication
templates.

* Slowly progressing on Q3_K on metal

We are now 13% faster than master

* nother small improvement for Q3_K on metal

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-metal.metal

index 5070561fba1ace47d5b224d13ec8e9a0a7771773..7b5c21d92ab63fca5a784c3842042ad8c85a9e4a 100644 (file)
@@ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q  += step;
             h  += step;
@@ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;