]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fixed mmvq kernel for bs 2,3,4 and -sm row (llama/5386)
authorJohannes Gäßler <redacted>
Wed, 7 Feb 2024 11:40:26 +0000 (12:40 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:47 +0000 (09:55 +0200)
ggml-cuda.cu

index 3b828375e8adbe590f4df573285ad5973d871180..db9da24594cb249c0dc8a38e88d7e74a379870f1 100644 (file)
@@ -5313,7 +5313,7 @@ template <bool need_check> static __global__ void
 template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
 
     const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
 
@@ -5352,7 +5352,7 @@ static __global__ void mul_mat_vec_q(
         tmp[j] = warp_reduce_sum(tmp[j]);
 
         if (threadIdx.x == 0) {
-            dst[j*nrows_x + row] = tmp[j];
+            dst[j*nrows_dst + row] = tmp[j];
         }
     }
 }
@@ -6828,7 +6828,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= 4);
@@ -6839,40 +6839,40 @@ static void mul_mat_vec_q_cuda(
     switch (ncols_y) {
         case 1:
             mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
         case 2:
             mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
         case 3:
             mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
         case 4:
             mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
         // case 5:
         //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
         //     break;
         // case 6:
         //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
         //     break;
         // case 7:
         //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
         //     break;
         // case 8:
         //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
         //     break;
         default:
             GGML_ASSERT(false);
             // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
-            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
     }
 }
@@ -8391,7 +8391,7 @@ static void ggml_cuda_op_mul_mat_q(
     CUDA_CHECK(cudaGetDevice(&id));
 
     // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+    // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
 
     switch (src0->type) {
@@ -8525,58 +8525,70 @@ static void ggml_cuda_op_mul_mat_vec_q(
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
+
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_1:
             mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_0:
             mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_1:
             mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q8_0:
             mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_K:
             mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -9909,7 +9921,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else if (use_mul_mat_q) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);