]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: mul_mat_vec_q kernels for k-quants (#2203)
authorJohannes Gäßler <redacted>
Fri, 14 Jul 2023 17:44:08 +0000 (19:44 +0200)
committerGitHub <redacted>
Fri, 14 Jul 2023 17:44:08 +0000 (19:44 +0200)
ggml-cuda.cu

index 920466aae72db8f01c44e759a3c7c6a0b856cf1f..4c9e21429e10f4ead6bdbb6261625b8cc9d84be0 100644 (file)
@@ -13,6 +13,8 @@
 #include "ggml-cuda.h"
 #include "ggml.h"
 
+#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -74,7 +76,7 @@ typedef void (*ggml_cuda_op_t)(
 
 #define QK4_0 32
 #define QR4_0 2
-#define QI4_0 4
+#define QI4_0 (QK4_0 / (4 * QR4_0))
 typedef struct {
     half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
@@ -83,7 +85,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 
 #define QK4_1 32
 #define QR4_1 2
-#define QI4_1 4
+#define QI4_1 (QK4_1 / (4 * QR4_1))
 typedef struct {
     half    d;              // delta
     half    m;              // min
@@ -93,7 +95,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
 
 #define QK5_0 32
 #define QR5_0 2
-#define QI5_0 4
+#define QI5_0 (QK5_0 / (4 * QR5_0))
 typedef struct {
     half d;                 // delta
     uint8_t qh[4];          // 5-th bit of quants
@@ -103,7 +105,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 
 #define QK5_1 32
 #define QR5_1 2
-#define QI5_1 4
+#define QI5_1 (QK5_1 / (4 * QR5_1))
 typedef struct {
     half d;                 // delta
     half m;                 // min
@@ -114,7 +116,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 #define QR8_0 1
-#define QI8_0 8
+#define QI8_0 (QK8_0 / (4 * QR8_0))
 typedef struct {
     half    d;              // delta
     int8_t  qs[QK8_0];      // quants
@@ -123,7 +125,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 #define QK8_1 32
 #define QR8_1 1
-#define QI8_1 8
+#define QI8_1 (QK8_1 / (4 * QR8_1))
 typedef struct {
     half    d;              // delta
     half    s;              // unquantized sum
@@ -143,6 +145,8 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 #define K_SCALE_SIZE 12
 #endif
 
+#define QR2_K 4
+#define QI2_K (QK_K / (4*QR2_K))
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
@@ -151,6 +155,8 @@ typedef struct {
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
@@ -163,6 +169,8 @@ typedef struct {
 } block_q3_K;
 //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#define QR4_K 2
+#define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
     half    d[2];              // super-block scales/mins
@@ -180,6 +188,8 @@ typedef struct {
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
 #endif
 
+#define QR5_K 2
+#define QI5_K (QK_K / (4*QR5_K))
 #ifdef GGML_QKK_64
 typedef struct {
     half d;                  // super-block scale
@@ -199,6 +209,8 @@ typedef struct {
 static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
 #endif
 
+#define QR6_K 2
+#define QI6_K (QK_K / (4*QR6_K))
 typedef struct {
     uint8_t ql[QK_K/2];   // quants, lower 4 bits
     uint8_t qh[QK_K/4];   // quants, upper 2 bits
@@ -1271,8 +1283,9 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
     int vi;
@@ -1293,11 +1306,12 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
     const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1318,11 +1332,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
     return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
     int qs;
@@ -1353,11 +1368,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
     const int qs  = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1387,11 +1403,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
     return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
     int vi;
@@ -1406,7 +1423,220 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric
     return sumi*d;
 #else
     return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 610
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq2_K->d;
+    const float dmin = bq2_K->dmin;
+
+    const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]);
+
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = bq2_K->scales[scale_offset + 2*i];
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const float d8i = bq8i->d;
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * (sc & 0xF)); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >>  4)); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    float sumf = 0.0f;
+
+    const float d = bq3_K->d;
+
+    int vl;
+    memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
+
+    int vh;
+    memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
+    vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted
+    vh >>= bq8_offset;
+
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    const int bq8_offset = QR4_K * (iqs / QI8_1);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq4_K->d;
+    const float dmin = bq4_K->dmin;
+
+    const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const int isc = bq8_offset + i;
+
+        uint8_t sc, m;
+        get_scale_min_k4(isc, bq4_K->scales, sc, m);
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    const int bq8_offset = QR5_K * (iqs / QI8_1);
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    const float    d = bq5_K->d;
+    const float dmin = bq5_K->dmin;
+
+    const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+
+    const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+
+    for (int i = 0; i < QR5_K; ++i) {
+        const int isc = bq8_offset + i;
+
+        uint8_t sc, m;
+        get_scale_min_k4(isc, bq5_K->scales, sc, m);
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> i) << 4) & 0x10101010;
+
+        const int vi = vil | vih;
+
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q5_K with sum of q8_1 values
+    }
+
+    return d*sumf_d - dmin*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+    float sumf = 0.0f;
+
+    const float d = bq6_K->d;
+
+    int vl;
+    memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int));
+
+    int vh;
+    memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int));
+
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = bq6_K->scales[scale_offset + 4*i];
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
+        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
+        const float d8i = bq8i->d;
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -1429,7 +1659,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
         const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
 
-        const int iby = i + threadIdx.x / qi; // y block index
+        const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx
 
         const int iqs  = threadIdx.x % qi; // x block quant index when casting the quants to int
 
@@ -1962,7 +2192,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 }
 
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1971,7 +2201,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1980,7 +2210,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1989,7 +2219,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1998,7 +2228,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -2006,6 +2236,51 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, vec_dot_q2_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -2494,13 +2769,22 @@ inline void ggml_cuda_op_mul_mat_vec(
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
-    const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
+    bool mul_mat_vec_q_implemented =
+        src0->type == GGML_TYPE_Q4_0 ||
         src0->type == GGML_TYPE_Q4_1 ||
         src0->type == GGML_TYPE_Q5_0 ||
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
-
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
+#if QK_K == 256
+    mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
+        src0->type == GGML_TYPE_Q2_K ||
+        src0->type == GGML_TYPE_Q3_K ||
+        src0->type == GGML_TYPE_Q4_K ||
+        src0->type == GGML_TYPE_Q5_K ||
+        src0->type == GGML_TYPE_Q6_K;
+#endif // QK_K == 256
+
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
@@ -2526,6 +2810,21 @@ inline void ggml_cuda_op_mul_mat_vec(
             case GGML_TYPE_Q8_0:
                 mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
                 break;
+            case GGML_TYPE_Q2_K:
+                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q3_K:
+                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q4_K:
+                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q5_K:
+                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
+            case GGML_TYPE_Q6_K:
+                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
             default:
                 GGML_ASSERT(false);
                 break;