]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Some more Q4_K and Q5_K speedup on CUDA (#2346)
authorKawrakow <redacted>
Sun, 23 Jul 2023 21:19:47 +0000 (00:19 +0300)
committerGitHub <redacted>
Sun, 23 Jul 2023 21:19:47 +0000 (00:19 +0300)
* Faster Q5_K on CUDA

* Small Q5_K improvement on older GPUs

* Spped up Q4_K on CUDA

GTX1660: 29.5 ms/t -> 25.6 ms/t
RTX4080: 8.40 ms/t -> 8.25 ms/t

* Spped up Q4_K on CUDA

GTX1660: 36.7 ms/t -> 35.6 ms/t
RTX4080:  9.8 ms/t ->  9.5 ms/t

* Address PR comments

* Add some comments to satisfy PR reviewer

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu

index 6fb55d838dfb3e4a7ec505f17a3adf334e085ffe..6823adc6cc958172e16e728d5d75a6ab15dfb06f 100644 (file)
@@ -1073,10 +1073,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
-        const uint8_t * ql2 = ql1 + 64;
         const uint8_t * qh  = x[i].qh + l0;
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
@@ -1092,15 +1094,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 
         float4 sum = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
         for (int l = 0; l < n; ++l) {
-            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
-                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
-            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
-                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
-            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
-                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
-            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
-                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
             smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
@@ -1554,7 +1566,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
 
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
@@ -1562,7 +1575,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const float    d = bq4_K->d;
     const float dmin = bq4_K->dmin;
 
-    const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int v1 = q4[0];
+    const int v2 = q4[4];
 
     const uint16_t * scales = (const uint16_t *)bq4_K->scales;
     uint16_t aux[2];
@@ -1580,13 +1600,19 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     for (int i = 0; i < QR4_K; ++i) {
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
-        const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+        const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
+        const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc[i]); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
     }
 
     return d*sumf_d - dmin*sumf_m;
@@ -1601,7 +1627,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
 
-    const int bq8_offset = QR5_K * (iqs / QI8_1);
+    const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
 
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
@@ -1609,28 +1637,48 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const float    d = bq5_K->d;
     const float dmin = bq5_K->dmin;
 
-    const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
 
-    const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+    const int vh1 = qh[0] >> bq8_offset;
+    const int vh2 = qh[4] >> bq8_offset;
 
-    for (int i = 0; i < QR5_K; ++i) {
-        const int isc = bq8_offset + i;
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
 
-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq5_K->scales, sc, m);
+    for (int i = 0; i < QR5_K; ++i) {
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
         const float d8i = bq8i->d;
+        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int ui1 = q8[0];
+        const int ui2 = q8[4];
 
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+        const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
+        const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
+        const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
+
+        const int vi1 = vil1 | vih1;
+        const int vi2 = vil2 | vih2;
 
-        const int vih = ((vh >> i) << 4) & 0x10101010;
+        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
 
-        const int vi = vil | vih;
+        sumf_d += d8i * (dot1 * sc[i]);
+        sumf_m += d8i * (dot2 * m[i]);
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q5_K with sum of q8_1 values
     }
 
     return d*sumf_d - dmin*sumf_m;
@@ -2306,7 +2354,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+    // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2315,7 +2366,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+    // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
+    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
+    //       is better amortized.
+    mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }