git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: mul_mat_vec_q max. batch size 8 -> 4 (llama/5370)
author: Johannes Gäßler <redacted>
Tue, 6 Feb 2024 17:43:06 +0000 (18:43 +0100)
committer: Georgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:47 +0000 (09:55 +0200)
ggml-cuda.cu

index 95161b3f4b3d6c1c35e9f423a1a933b0735bec4f..3b828375e8adbe590f4df573285ad5973d871180 100644 (file)
@@ -6831,7 +6831,7 @@ static void mul_mat_vec_q_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
-    GGML_ASSERT(ncols_y <= 8);
+    GGML_ASSERT(ncols_y <= 4);
 
     const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
@@ -6853,22 +6853,22 @@ static void mul_mat_vec_q_cuda(
             mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
             break;
-        case 5:
-            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 6:
-            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 7:
-            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 8:
-            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
+        // case 5:
+        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 6:
+        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 7:
+        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 8:
+        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
         default:
             GGML_ASSERT(false);
             // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
@@ -9909,7 +9909,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else if (use_mul_mat_q) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);