]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : faster k-quants on older GPUs (#1930)
authorKawrakow <redacted>
Mon, 19 Jun 2023 15:14:09 +0000 (18:14 +0300)
committerGitHub <redacted>
Mon, 19 Jun 2023 15:14:09 +0000 (18:14 +0300)
* k_quants: hopefully much faster Q4_K on older GPUs

On the GTX-1660 that I have available to represent
"old GPUs", token prediction drops from 65.5 ms/tok
to 41.5 ms/tok!

* k_quants: hopefully much faster Q3_K on older GPUs

On the GTX-1660 that I have available to represent
"old GPUs", token prediction drops from 60.3 ms/tok
to 41.0 ms/tok!

* k_quants: faster Q2_K on older GPUs

It looks like I didn't need to change anything
compared to what we already had, so this is just
adding clarifying comments. But I now measure
36.3 ms/tok on the GTX-1660, instead fo the
47.2 ms/tok that I have written in the faster
k-quants PR.

* k_quants: faster Q5_K on older GPUs

68.5 ms/tok -> 62.0 ms/tok on GTX-1660.
For some reason the same access pattern that leads
to such resounding success for Q2_K to Q4_K did not
work at all for Q5_K.

It is also more difficult to measure because for Q5_K_S
we only have 32 layers on the GTX-1660, so output, tok embeddings
and kv cache are done on the CPU.

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-cuda.cu

index 9ebc57aff25d6b960e66c1112a2f0681d75cf48e..36a251ecce973bddee7df398cd61931fbed22cf9 100644 (file)
@@ -515,15 +515,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
     const int step = 16/K_QUANTS_PER_ITERATION;
 
-    const int im = tid/step;      // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im; // 0...7
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;        // 0...14 in steps of 4
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
     const int q_offset = 32*im + l0;
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
@@ -578,27 +578,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int row = blockIdx.x;
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
-    const int tid = threadIdx.x/2;  // 0...15
-    const int ix  = threadIdx.x%2;  // 0, 1
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    const int n  = 2;           // iterations in the inner loop
-    const int im = tid/8;       // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - 8*im;  // 0...7
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
     const uint8_t m = 1 << (4*im);
 
-    const int l0 = n*in;        // 0...28 in steps of 4
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
     const int q_offset =  32*im + l0;
     const int y_offset = 128*im + l0;
 
@@ -609,7 +612,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 2) {
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float   * y  = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -650,22 +653,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int row = blockIdx.x;
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
-    const int tid = threadIdx.x/2;  // 0...15
-    const int ix  = threadIdx.x%2;
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -681,7 +687,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 2) {
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const uint8_t * q1 = x[i].qs + q_offset;
         const uint8_t * q2 = q1 + 64;
@@ -736,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
 
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
+    const int n   = 2;
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -775,11 +781,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         float4 sum = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
         for (int l = 0; l < n; ++l) {
-            sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0));
-            sum.y += y1[l+32] * ((ql1[l] >>  4) + (qh[l] & (hm1 << 1) ? 16 : 0));
-            sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0));
-            sum.w += y2[l+32] * ((ql2[l] >>  4) + (qh[l] & (hm2 << 1) ? 16 : 0));
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
 
@@ -1311,7 +1322,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
 
 static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2;
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(32, ny, 1);
@@ -1320,14 +1331,20 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(32, 1, 1);
-    dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(32, 1, 1);
-    dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {