]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: faster non k-quant mul_mat_q kernels (#2483)
authorJohannes Gäßler <redacted>
Wed, 2 Aug 2023 16:04:04 +0000 (18:04 +0200)
committerGitHub <redacted>
Wed, 2 Aug 2023 16:04:04 +0000 (18:04 +0200)
ggml-cuda.cu

index a4dd6bb9df99769e90e3898423884383861ca984..e0192bc6ecebc6052330b0c209a2eb577c73eb5f 100644 (file)
@@ -1362,22 +1362,185 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
 }
 
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
-#define VDR_q4_0_q8_1 1
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
-    const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8) {
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+    const int * v, const int * u, const float & d4, const half2 & ds8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    // subtract 8 from each quantized value
-    const int vi0 = (vi >> 0) & 0x0F0F0F0F;
-    const int vi1 = (vi >> 4) & 0x0F0F0F0F;
+    int sumi = 0;
 
-    // SIMD dot product of quantized values
-    int sumi = __dp4a(vi0, ui0, 0);
-    sumi     = __dp4a(vi1, ui1, sumi);
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
+
+    // second part effectively subtracts 8 from each quant value
+    return d4 * (sumi * __half2float(ds8.x) - (8*vdr/QI4_0) * __half2float(ds8.y));
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
+
+#ifdef GGML_CUDA_F16
+    const half2 tmp = __hmul2(dm4, ds8);
+    const float d4d8 = __half2float(tmp.x);
+    const float m4s8 = __half2float(tmp.y);
+#else
+    const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
+    const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+    // second part effectively subtracts 16 from each quant value
+    return d5 * (sumi*__half2float(ds8.x) - (16*vdr/QI5_0) * __half2float(ds8.y));
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+#ifdef GGML_CUDA_F16
+    const half2 tmp = __hmul2(dm5, ds8);
+    const float d5d8 = __half2float(tmp.x);
+    const float m5s8 = __half2float(tmp.y);
+#else
+    const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
+    const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 
-    return __half2float(d4) * (sumi * __half2float(ds8.x) - (8/QI4_0) * __half2float(ds8.y));
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
+    const int * v, const int * u, const float & d8_0, const half2 & ds8_1) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+
+    return sumi * d8_0 * __half2float(ds8_1.x);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+
+#ifdef GGML_CUDA_F16
+    const half2 tmp = __hmul2(dm8, ds8);
+    const float d8d8 = __half2float(tmp.x);
+    const float m8s8 = __half2float(tmp.y);
+#else
+    const float d8d8 = __half2float(dm8.x) * __half2float(ds8.x);
+    const float m8s8 = __half2float(dm8.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1388,20 +1551,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
-    const int vi  = get_int_from_uint8(bq4_0->qs, iqs);
-    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
-    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_0);
+    int v[VDR_Q4_0_Q8_1_MMVQ];
+    int u[2*VDR_Q4_0_Q8_1_MMVQ];
 
-    return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_from_uint8(bq4_0->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
     __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
 
     *x_ql = tile_x_qs;
-    *x_dm = tile_x_d;
+    *x_dm = (half2 *) tile_x_d;
 }
 
 template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
@@ -1418,6 +1587,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
 
     const block_q4_0 * bx0 = (block_q4_0 *) vx;
 
+    float * x_dmf = (float *) x_dm;
+
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
         int i = i0 + i_offset;
@@ -1429,7 +1600,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
         const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-        x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx].x = bxi->d;
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
     }
 
 //     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -1462,39 +1633,19 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     __builtin_assume(k <  WARP_SIZE);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const float * x_dmf = (float *) x_dm;
 
-    return vec_dot_q4_0_q8_1_impl(
-        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
-        x_dm[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
-}
-
-#define VDR_q4_1_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
-    const int & vi, const int & ui0, const int & ui1, const half2 & dm4, const half2 & ds8) {
+    int u[2*VDR_Q4_0_Q8_1_MMQ];
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const int vi0 = (vi >> 0) & 0x0F0F0F0F;
-    const int vi1 = (vi >> 4) & 0x0F0F0F0F;
-
-    // SIMD dot product of quantized values
-    int sumi = __dp4a(vi0, ui0, 0);
-    sumi     = __dp4a(vi1, ui1, sumi);
-
-#ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm4, ds8);
-    const float d4d8 = __half2float(tmp.x);
-    const float m4s8 = __half2float(tmp.y);
-#else
-    const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
-    const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
-#endif // GGML_CUDA_F16
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_0];
+    }
 
-    // scale second part of sum by QI8_1/QR4_1 to compensate for multiple threads adding it
-    return sumi * d4d8 + m4s8 / (QI8_1 / QR4_1);
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
@@ -1502,11 +1653,17 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
-    const int vi = get_int_from_uint8_aligned(bq4_1->qs, iqs);
-    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
-    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_1);
+    int v[VDR_Q4_1_Q8_1_MMVQ];
+    int u[2*VDR_Q4_1_Q8_1_MMVQ];
 
-    return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds);
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -1575,35 +1732,17 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
-    return vec_dot_q4_1_q8_1_impl(
-        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
-        x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
-}
-
-#define VDR_q5_0_q8_1 1
+    int u[2*VDR_Q4_1_Q8_1_MMQ];
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
-    const int & qs, const int & qh, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) {
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_1];
+    }
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    int vi0 = (qs >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
-    vi0    |= (qh <<  4) & 0x00000010; // 0 ->  4
-    vi0    |= (qh << 11) & 0x00001000; // 1 -> 12
-    vi0    |= (qh << 18) & 0x00100000; // 2 -> 20
-    vi0    |= (qh << 25) & 0x10000000; // 3 -> 28
-    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
-
-    int vi1 = (qs >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
-    vi1    |= (qh >> 12) & 0x00000010; // 16 ->  4
-    vi1    |= (qh >>  5) & 0x00001000; // 17 -> 12
-    vi1    |= (qh <<  2) & 0x00100000; // 18 -> 20
-    vi1    |= (qh <<  9) & 0x10000000; // 19 -> 28
-    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
-
-    return __half2float(d5) * (sumi*__half2float(ds8.x) - (16/QI5_0) * __half2float(ds8.y));
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
@@ -1611,23 +1750,28 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
-    const int qs = get_int_from_uint8(bq5_0->qs, iqs);
-    const int qh = get_int_from_uint8(bq5_0->qh, 0) >> (4 * iqs);
-    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
-    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0);
+    int vl[VDR_Q5_0_Q8_1_MMVQ];
+    int vh[VDR_Q5_0_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
 
-    return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds);
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);
+        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+    }
+
+    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int  tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ int  tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
-    __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
+    __shared__ int  tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
 
     *x_ql = tile_x_ql;
-    *x_qh = tile_x_qh;
-    *x_dm = tile_x_d;
+    *x_dm = (half2 *) tile_x_d;
 }
 
 template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
@@ -1654,11 +1798,31 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
     const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
@@ -1670,8 +1834,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_qh[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd]   = get_int_from_uint8(bxi->qh, 0);
-        x_dm[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd].x = bxi->d;
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
     }
 }
 
@@ -1688,46 +1851,18 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+    const float * x_dmf = (float *) x_dm;
 
-    return vec_dot_q5_0_q8_1_impl(
-        x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs],
-        y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
-}
-
-#define VDR_q5_1_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
-    const int & qs, const int & qh, const int & ui0, const int & ui1, const half2 & dm5, const half2 & ds8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    int vi0 = (qs >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
-    vi0    |= (qh <<  4) & 0x00000010; // 0 ->  4
-    vi0    |= (qh << 11) & 0x00001000; // 1 -> 12
-    vi0    |= (qh << 18) & 0x00100000; // 2 -> 20
-    vi0    |= (qh << 25) & 0x10000000; // 3 -> 28
-    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
-
-    int vi1 = (qs >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
-    vi1    |= (qh >> 12) & 0x00000010; // 16 ->  4
-    vi1    |= (qh >>  5) & 0x00001000; // 17 -> 12
-    vi1    |= (qh <<  2) & 0x00100000; // 18 -> 20
-    vi1    |= (qh <<  9) & 0x10000000; // 19 -> 28
-    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
-
-#ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm5, ds8);
-    const float d5d8 = __half2float(tmp.x);
-    const float m5s8 = __half2float(tmp.y);
-#else
-    const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
-    const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
-#endif // GGML_CUDA_F16
+    int u[2*VDR_Q5_0_Q8_1_MMQ];
 
-    return sumi*d5d8 + m5s8/QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_0];
+    }
 
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@@ -1735,22 +1870,27 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
-    const int qs = get_int_from_uint8_aligned(bq5_1->qs, iqs);
-    const int qh = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * iqs);
-    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
-    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_1);
+    int vl[VDR_Q5_1_Q8_1_MMVQ];
+    int vh[VDR_Q5_1_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+    }
 
-    return vec_dot_q5_1_q8_1_impl(qs, qh, ui0, ui1, bq5_1->dm, bq8_1->ds);
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE )      + GGML_CUDA_MMQ_Y];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
     __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
 
     *x_ql = tile_x_ql;
-    *x_qh = tile_x_qh;
     *x_dm = tile_x_dm;
 }
 
@@ -1778,7 +1918,24 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -1794,7 +1951,6 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_qh[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = get_int_from_uint8_aligned(bxi->qh, 0);
         x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
     }
 }
@@ -1813,24 +1969,16 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
 
-    return vec_dot_q5_1_q8_1_impl(
-        x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2*WARP_SIZE) + kyqs],
-        y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
-}
-
-#define VDR_q8_0_q8_1 1
+    int u[2*VDR_Q5_1_Q8_1_MMQ];
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
-    const int & vi, const int & ui, const half & d8_0, const half2 & ds8_1) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    // SIMD dot product of quantized values
-    const int sumi = __dp4a(vi, ui, 0);
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_1];
+    }
 
-    return sumi * __half2float(d8_0) * __half2float(ds8_1.x);
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
@@ -1838,19 +1986,24 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
-    const int vi = get_int_from_int8(bq8_0->qs, iqs);
-    const int ui = get_int_from_int8_aligned(bq8_1->qs, iqs);
+    int v[VDR_Q8_0_Q8_1_MMVQ];
+    int u[VDR_Q8_0_Q8_1_MMVQ];
 
-    return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds);
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+    }
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
     __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
 
     *x_ql = tile_x_qs;
-    *x_dm = tile_x_d;
+    *x_dm = (half2 *) tile_x_d;
 }
 
 template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
@@ -1864,6 +2017,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
 
     const int kbx  = k / QI8_0;
     const int kqsx = k % QI8_0;
+    float * x_dmf = (float *) x_dm;
 
     const block_q8_0 * bx0 = (block_q8_0 *) vx;
 
@@ -1878,7 +2032,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
-        x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx].x = bxi->d;
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx] = bxi->d;
     }
 
 //     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -1912,9 +2066,11 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
-    return vec_dot_q8_0_q8_1_impl(
-        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k],
-        x_dm[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+    const float * x_dmf = (float *) x_dm;
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+         y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
 }
 
 #define VDR_q2_K_q8_1 1
@@ -2288,15 +2444,15 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     int    u[2*QR4_K];
     float d8[QR4_K];
 
-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
 
     // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
     // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
     // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
     // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
 
-    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
     v[0] = q4[0];
     v[1] = q4[4];
 
@@ -2317,7 +2473,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         d8[i] = bq8i->ds.x;
 
-        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
         u[2*i+1] = q8[4];
     }
@@ -2345,12 +2501,12 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const float d8_1 = bq8_1[0].ds.x;
     const float d8_2 = bq8_1[1].ds.x;
 
-    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
-    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
-    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
 
-    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int * q4 = (const int *)bq4_K->qs + (iqs/2);
     const int v1 = q4[0];
     const int v2 = q4[4];
 
@@ -2457,11 +2613,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     int    u[2*QR4_K];
     float d8[QR4_K];
 
-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (kqsx / (QI8_1/2));
+    // kqsx is in 0,2...30. bq8_offset = 2 * (kqsx/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((kqsx/2) / (QI8_1/2));
 
-    v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
-    v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
+    v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
+    v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
 
     const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
     uint16_t aux[2];
@@ -2477,7 +2633,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const uint8_t * m  = sc + 2;
 
     for (int l = 0; l < QR4_K; ++l) {
-        const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2);
+        const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
         u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
         u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
         d8[l] = y_ds[kqsy / QI8_1].x;
@@ -2532,9 +2688,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     int    u[2*QR5_K];
     float d8[QR5_K];
 
-    const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
-    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
-    const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
+    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
 
     vl[0] = ql[0];
     vl[1] = ql[4];
@@ -2559,7 +2715,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         d8[i] = bq8i->ds.x;
 
-        const int * q8 = (const int *)bq8i->qs + (iqs%4);
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
         u[2*i+1] = q8[4];
     }
@@ -2578,17 +2734,17 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const float d8_1 = bq8_1[0].ds.x;
     const float d8_2 = bq8_1[1].ds.x;
 
-    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
-    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
-    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
 
-    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int * ql = (const int *)bq5_K->qs + (iqs/2);
     const int vl1 = ql[0];
     const int vl2 = ql[4];
 
-    const int step = 4 * iqs; // 0, 4, 8, 12
-    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int step = 4 * (iqs/2); // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
     const int in = step%8; // 0, 4, 0, 4
     const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
 
@@ -2711,13 +2867,13 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     int    u[2*QR4_K];
     float d8[QR4_K];
 
-    const int bq8_offset = QR5_K * (kqsx / (QI8_1/2));
+    const int bq8_offset = QR5_K * ((kqsx/2) / (QI8_1/2));
 
-    vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
-    vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
+    vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
+    vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
 
-    vh[0] = x_qh[i * (WARP_SIZE/4) + i/4 + kqsx % 4 + 0] >> bq8_offset;
-    vh[1] = x_qh[i * (WARP_SIZE/4) + i/4 + kqsx % 4 + 4] >> bq8_offset;
+    vh[0] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 0] >> bq8_offset;
+    vh[1] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 4] >> bq8_offset;
 
     const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
     uint16_t aux[2];
@@ -2733,7 +2889,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     const uint8_t * m  = sc + 2;
 
     for (int l = 0; l < QR5_K; ++l) {
-        const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2);
+        const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
         u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
         u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
         d8[l] = y_ds[kqsy / QI8_1].x;
@@ -2982,7 +3138,7 @@ static __global__ void mul_mat_q(
 #if __CUDA_ARCH__ >= 700 // Unrolling the loop is slower on Pascal
 #pragma unroll
 #endif // __CUDA_ARCH__ >= 700
-        for (int k = 0; k < WARP_SIZE/vdr; ++k) {
+        for (int k = 0; k < WARP_SIZE; k += vdr) {
 #pragma unroll
             for (int j = 0; j < WARP_SIZE; j += 8) {
 #pragma unroll
@@ -3034,9 +3190,9 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
         const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
 
-        const int iby = (i + threadIdx.x / (qi/vdr)) * qk/QK8_1; // y block index that aligns with ibx
+        const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
 
-        const int iqs  = threadIdx.x % (qi/vdr); // x block quant index when casting the quants to int
+        const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
         tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
     }
@@ -3579,7 +3735,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1>
+    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3588,7 +3744,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1>
+    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3597,7 +3753,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1>
+    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3606,7 +3762,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1>
+    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3615,7 +3771,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1>
+    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3717,10 +3873,10 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3735,10 +3891,10 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3753,10 +3909,10 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3771,10 +3927,10 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3789,10 +3945,10 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }