]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : sync latest llama.cpp (control DMMV X/Y sizes)
authorGeorgi Gerganov <redacted>
Sat, 27 May 2023 13:20:24 +0000 (16:20 +0300)
committerGeorgi Gerganov <redacted>
Sat, 27 May 2023 13:20:24 +0000 (16:20 +0300)
src/ggml-cuda.cu

index 35d2e457cbf8487fd967b6459925a85d1ab0282b..98170a3ae17de41c54e36a20e37024ea8f38f51c 100644 (file)
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
+#endif
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
-    __shared__ float tmp[block_size]; // separate sum for each thread
-    tmp[tid] = 0;
+    float tmp = 0; // partial sum for thread in warp
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // block index
+        const int iqs = (col%qk)/qr; // quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
     __syncthreads();
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
+
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = tmp;
     }
 }
 
@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {