]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: mul_mat_vec_q for batch sizes > 1 (llama/5351)
authorJohannes Gäßler <redacted>
Tue, 6 Feb 2024 13:44:06 +0000 (14:44 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:47 +0000 (09:55 +0200)
ggml-cuda.cu

index 3242a0b4ad7bf8c0440810135ae66692747eb23f..95161b3f4b3d6c1c35e9f423a1a933b0735bec4f 100644 (file)
@@ -5310,41 +5310,50 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
+template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
+
+    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
 
-    if (row >= nrows) {
+    if (row >= nrows_x) {
         return;
     }
 
-    const int blocks_per_row = ncols / qk;
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
 // partial sum for each thread
-    float tmp = 0.0f;
+    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
-        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+        }
     }
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    for (int j = 0; j < ncols_y; ++j) {
+        tmp[j] = warp_reduce_sum(tmp[j]);
 
-    if (threadIdx.x == 0) {
-        dst[row] = tmp;
+        if (threadIdx.x == 0) {
+            dst[j*nrows_x + row] = tmp[j];
+        }
     }
 }
 
@@ -6816,121 +6825,56 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK5_1 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK8_0 == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+static void mul_mat_vec_q_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
 
-static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
-}
+    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_y <= 8);
 
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+        default:
+            GGML_ASSERT(false);
+            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
+            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+            break;
+    }
 }
 
 static void ggml_mul_mat_q4_0_q8_1_cuda(
@@ -8578,50 +8522,61 @@ static void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -9945,17 +9900,18 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
 #endif // GGML_CUDA_FORCE_DMMV
 
             if (use_mul_mat_vec_q) {
-                // NOTE: this kernel does not support ggml_nrows(src1) > 1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            if (use_mul_mat_q) {
+            if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+            } else if (use_mul_mat_q) {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);