]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda : move row numbers to x grid dim in mmv kernels (#3921)
authorslaren <redacted>
Fri, 3 Nov 2023 11:13:09 +0000 (12:13 +0100)
committerGitHub <redacted>
Fri, 3 Nov 2023 11:13:09 +0000 (12:13 +0100)
ggml-cuda.cu

index baf02df2b229489333aca91f6fca68ed97ea0d9c..bdbcca0cabb88e4d26cb2d168a89e8def07d7491 100644 (file)
@@ -989,7 +989,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1093,7 +1093,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -1197,7 +1197,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
@@ -1451,7 +1451,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
     const int num_blocks_per_row = ncols / QK_K;
@@ -4261,7 +4261,7 @@ template <bool need_check> static __global__ void
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4301,7 +4301,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -4874,7 +4874,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4883,7 +4884,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4892,7 +4893,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4901,7 +4902,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4910,7 +4911,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4920,7 +4921,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4929,7 +4930,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4938,7 +4939,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4953,7 +4954,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     GGML_ASSERT(ncols % QK_K == 0);
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
@@ -4961,7 +4962,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4970,7 +4971,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4979,7 +4980,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4988,7 +4989,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4997,7 +4998,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5006,7 +5007,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5015,7 +5016,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5024,7 +5025,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5033,7 +5034,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5042,7 +5043,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
 static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5061,7 +5062,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
     dequantize_mul_mat_vec<1, 1, convert_f16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);