]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: faster k-quant mul_mat_q kernels (#2525)
authorJohannes Gäßler <redacted>
Sat, 5 Aug 2023 16:20:44 +0000 (18:20 +0200)
committerGitHub <redacted>
Sat, 5 Aug 2023 16:20:44 +0000 (18:20 +0200)
ggml-cuda.cu

index d64d7045ccf3e50e4051898fd1a9a646bfe89379..9d42efb0d0b03a98b11c16f19260a10c8b9009da 100644 (file)
@@ -1383,8 +1383,10 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
         sumi = __dp4a(vi1, u[2*i+1], sumi);
     }
 
+    const float2 ds8f = __half22float2(ds8);
+
     // second part effectively subtracts 8 from each quant value
-    return d4 * (sumi * __half2float(ds8.x) - (8*vdr/QI4_0) * __half2float(ds8.y));
+    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1410,12 +1412,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm4, ds8);
-    const float d4d8 = __half2float(tmp.x);
-    const float m4s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+    const float d4d8 = tmp.x;
+    const float m4s8 = tmp.y;
 #else
-    const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
-    const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
+    const float2 dm4f = __half22float2(dm4);
+    const float2 ds8f = __half22float2(ds8);
+    const float d4d8 = dm4f.x * ds8f.x;
+    const float m4s8 = dm4f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
@@ -1434,6 +1438,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
         vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
@@ -1450,8 +1455,10 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
         sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
     }
 
+    const float2 ds8f = __half22float2(ds8);
+
     // second part effectively subtracts 16 from each quant value
-    return d5 * (sumi*__half2float(ds8.x) - (16*vdr/QI5_0) * __half2float(ds8.y));
+    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1466,6 +1473,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
         vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
@@ -1483,12 +1491,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm5, ds8);
-    const float d5d8 = __half2float(tmp.x);
-    const float m5s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+    const float d5d8 = tmp.x;
+    const float m5s8 = tmp.y;
 #else
-    const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
-    const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
+    const float2 dm5f = __half22float2(dm5);
+    const float2 ds8f = __half22float2(ds8);
+    const float d5d8 = dm5f.x * ds8f.x;
+    const float m5s8 = dm5f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
@@ -1503,17 +1513,18 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 #define VDR_Q8_0_Q8_1_MMQ 8
 
 template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
-    const int * v, const int * u, const float & d8_0, const half2 & ds8_1) {
+    const int * v, const int * u, const float & d8_0, const float & d8_1) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         // SIMD dot product of quantized values
         sumi = __dp4a(v[i], u[i], sumi);
     }
 
-    return sumi * d8_0 * __half2float(ds8_1.x);
+    return d8_0*d8_1 * sumi;
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1525,18 +1536,21 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         // SIMD dot product of quantized values
         sumi = __dp4a(v[i], u[i], sumi);
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm8, ds8);
-    const float d8d8 = __half2float(tmp.x);
-    const float m8s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+    const float d8d8 = tmp.x;
+    const float m8s8 = tmp.y;
 #else
-    const float d8d8 = __half2float(dm8.x) * __half2float(ds8.x);
-    const float m8s8 = __half2float(dm8.y) * __half2float(ds8.y);
+    const float2 dm8f = __half22float2(dm8);
+    const float2 ds8f = __half22float2(ds8);
+    const float d8d8 = dm8.x * ds8.x;
+    const float m8s8 = dm8.y * ds8.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
@@ -1546,6 +1560,312 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+
+        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return dm2f.x*sumf_d - dm2f.y*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float & d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi_d = 0;
+    int sumi_m = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+        int sumi_d_sc = 0;
+
+        const int sc = scales[i0 / (QI8_1/2)];
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+            sumi_m    = __dp4a(m,    u[i], sumi_m); // multiply sum of q8_1 values with m
+        }
+
+        sumi_d += sumi_d_sc * (sc & 0xF);
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d3 * sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d3, const float & d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+// also used for q5_K
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
+            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
+            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+
+        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
+        sumf_m += ds8f.y *   m[i0/4]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+
+    }
+
+    const float2 dm5f = __half22float2(dm5);
+
+    return dm5f.x*sumf_d - dm5f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+    const float & d6, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+            sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+            sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+            sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+        }
+
+        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+    }
+
+    return d6 * sumf_d;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
@@ -1631,6 +1951,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_0_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const float * x_dmf = (float *) x_dm;
@@ -1729,6 +2050,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_1_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
@@ -1848,10 +2170,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_0_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
     int u[2*VDR_Q5_0_Q8_1_MMQ];
 
@@ -1862,7 +2186,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     }
 
     return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@@ -1965,6 +2289,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_1_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -1989,12 +2314,13 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
     int v[VDR_Q8_0_Q8_1_MMVQ];
     int u[VDR_Q8_0_Q8_1_MMVQ];
 
+#pragma unroll
     for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
         v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds);
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2065,43 +2391,14 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q8_0_Q8_1_MMQ == 0);
 
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
     return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
         (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
-}
-
-#define VDR_q2_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
-    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const half2 & dm, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR2_K; ++i) {
-        const int sc = scales[2*i];
-
-        const int vi = (v >> (2*i)) & 0x03030303;
-
-        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
-
-        int sc_high = sc >> 4;
-        sc_high |= sc_high <<  8;
-        sc_high |= sc_high << 16;
-        sumf_m += d8[i] * __dp4a(sc_high, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
-    }
-
-    const float2 dmf = __half22float2(dm);
-
-    return dmf.x*sumf_d - dmf.y*sumf_m;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
@@ -2115,15 +2412,16 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
     const uint8_t * scales = bq2_K->scales + scale_offset;
 
     const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
-    int u[QR2_K];
+    int    u[QR2_K];
     float d8[QR2_K];
 
+#pragma unroll
     for (int i = 0; i < QR2_K; ++ i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + i].ds.x;
     }
 
-    return vec_dot_q2_K_q8_1_impl(v, u, scales, bq2_K->dm, d8);
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2204,62 +2502,26 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q2_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI2_K;
-    const int kqsx = k % QI2_K;
-
-    const int bq8_offset = QR2_K * (kqsx / QI8_1);
-    const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
+    const int kbx = k / QI2_K;
+    const int ky  = (k % QI2_K) * QR2_K;
+    const float * y_df = (const float *) y_ds;
 
-    const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16 + scale_offset;
+    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
 
-    int u[QR2_K];
-    float d8[QR2_K];
+    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
 
-    for (int l = 0; l < QR2_K; ++ l) {
-        const int y_qs_index = j * (QR2_K*WARP_SIZE) + kbx * (QR2_K*QI2_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[y_qs_index];
-        d8[l] = y_ds[y_qs_index / QI8_1].x;
+#pragma unroll
+    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
     }
 
-    return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], d8);
-}
-
-#define VDR_q3_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl(
-    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const int & scale_offset, const float & d, const float * __restrict__ d8) {
+    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf = 0.0f;
-
-    for (int i = 0; i < QR3_K; ++i) {
-        const int isc = scale_offset + 2*i;
-
-        const int isc_low = isc % (QK_K/32);
-        const int sc_shift_low = 4 * (isc / (QK_K/32));
-        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
-
-        const int isc_high = isc % (QK_K/64);
-        const int sc_shift_high = 2 * (isc / (QK_K/64));
-        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
-
-        const int sc = (sc_low | sc_high) - 32;
-
-        const int vil = (vl >> (2*i)) & 0x03030303;
-
-        const int vih = ((vh >> i) << 2) & 0x04040404;
-
-        const int vi = __vsubss4(vil, vih);
-
-        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
-
-    return d*sumf;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR2_K*WARP_SIZE) + QR2_K*k;
+    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
@@ -2277,15 +2539,16 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
     // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
     const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
 
-    int u[QR3_K];
+    int    u[QR3_K];
     float d8[QR3_K];
 
+#pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + i].ds.x;
     }
 
-    return vec_dot_q3_K_q8_1_impl(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2330,6 +2593,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
     const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
@@ -2341,7 +2605,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd].x = bxi->d;
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
     }
 
 #pragma unroll
@@ -2354,7 +2618,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
 
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
     }
 
 #pragma unroll
@@ -2367,7 +2632,19 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
 
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->scales, k % (QI3_K/4));
+        const int ksc = k % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
     }
 }
 
@@ -2381,57 +2658,31 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q3_K_Q8_1_MMQ == 0);
 
     const int kbx  = k / QI3_K;
-    const int kqsx = k % QI3_K;
+    const int ky  = (k % QI3_K) * QR3_K;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
-    const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
-    const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
+    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
-    const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16;
+    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 
-    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-    const int vh = ~x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
-
-    int    u[QR3_K];
-    float d8[QR3_K];
-
-    for (int l = 0; l < QR3_K; ++ l) {
-        const int y_qs_index = j * (QR3_K*WARP_SIZE) + kbx * (QR3_K*QI3_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[y_qs_index];
-        d8[l] = y_ds[y_qs_index / QI8_1].x;
-    }
-
-    return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset,
-                                  x_dm[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx].x, d8);
-}
-
-#define VDR_q4_K_q8_1 2
-
-static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl(
-    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
-    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR4_K; ++i) {
-        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
-        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+#pragma unroll
+    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+        const int shift = 2 * ((ky % 32) / 8);
+        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
 
-        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+        const int vlh = (vh << 2) & 0x04040404;
 
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+        v[l] = __vsubss4(vll, vlh);
     }
 
-    return __half2float(dm4.x)*sumf_d - __half2float(dm4.y)*sumf_m;
-
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR3_K*WARP_SIZE) + k*QR3_K;
+    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
@@ -2478,7 +2729,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
         u[2*i+1] = q8[4];
     }
 
-    return vec_dot_q4_K_q8_1_impl(v, u, sc, m, bq4_K->dm, d8);
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
 
 #else
 
@@ -2566,7 +2817,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
@@ -2591,7 +2842,15 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI4_K/8));
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
     }
 }
 
@@ -2605,76 +2864,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
-
-    int    v[2];
-    int    u[2*QR4_K];
-    float d8[QR4_K];
-
-    // kqsx is in 0,2...30. bq8_offset = 2 * (kqsx/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * ((kqsx/2) / (QI8_1/2));
-
-    v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
-    v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
-
-    const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
-    uint16_t aux[2];
-    const int l = bq8_offset/2;
-    if (l < 2) {
-        aux[0] = scales[l+0] & 0x3f3f;
-        aux[1] = scales[l+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
-
-    for (int l = 0; l < QR4_K; ++l) {
-        const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
-        u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
-        u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
-        d8[l] = y_ds[kqsy / QI8_1].x;
-    }
-
-    return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K + kbx], d8);
-}
-
-#define VDR_q5_K_q8_1 2
-
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
-    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
-    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR5_K; ++i) {
-        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
-        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
-
-        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
-        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
-
-        const int v0i = vl0i | vh0i;
-        const int v1i = vl1i | vh1i;
-
-        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
-
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]);
+    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
 
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
+        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
+        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
     }
 
-    return __half2float(dm5.x)*sumf_d - __half2float(dm5.y)*sumf_m;
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR4_K*WARP_SIZE) + QR4_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
@@ -2711,6 +2914,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const uint8_t * sc = (const uint8_t *)aux;
     const uint8_t * m  = sc + 2;
 
+#pragma unroll
     for (int i = 0; i < QR5_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         d8[i] = bq8i->ds.x;
@@ -2767,14 +2971,12 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 
 static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
     __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)     + GGML_CUDA_MMQ_Y/4];
     __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
-    *x_qh = tile_x_qh;
     *x_sc = tile_x_sc;
 }
 
@@ -2801,12 +3003,25 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
         }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR5_K*kqsx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
@@ -2821,19 +3036,6 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
     }
 
-#pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI5_K/4);
-
-        x_qh[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->qh, k % (QI5_K/4));
-    }
-
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
         int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
@@ -2844,7 +3046,15 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI5_K/8));
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
     }
 }
 
@@ -2858,71 +3068,13 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
-
-    int   vl[2];
-    int   vh[2];
-    int    u[2*QR4_K];
-    float d8[QR4_K];
-
-    const int bq8_offset = QR5_K * ((kqsx/2) / (QI8_1/2));
-
-    vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
-    vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
-
-    vh[0] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 0] >> bq8_offset;
-    vh[1] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 4] >> bq8_offset;
-
-    const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
-    uint16_t aux[2];
-    const int l = bq8_offset/2;
-    if (l < 2) {
-        aux[0] = scales[l+0] & 0x3f3f;
-        aux[1] = scales[l+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
-
-    for (int l = 0; l < QR5_K; ++l) {
-        const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
-        u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
-        u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
-        d8[l] = y_ds[kqsy / QI8_1].x;
-    }
-
-    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K + kbx], d8);
-}
-
-#define VDR_q6_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl(
-    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
-    const float & d, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf = 0.0f;
-
-    for (int i = 0; i < QR6_K; ++i) {
-        const int sc = scales[4*i];
-
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
-
-        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
-
-        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
-
-        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
 
-    return d*sumf;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+    const int index_y = j * (QR5_K*WARP_SIZE)     + QR5_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
@@ -2942,24 +3094,23 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
     int    u[QR6_K];
     float d8[QR6_K];
 
+#pragma unroll
     for (int i = 0; i < QR6_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
     }
 
-    return vec_dot_q6_K_q8_1_impl(vl, vh, u, scales, bq6_K->d, d8);
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
     __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)     + GGML_CUDA_MMQ_Y/2];
     __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
-    *x_qh = tile_x_qh;
     *x_sc = tile_x_sc;
 }
 
@@ -2986,12 +3137,26 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
         }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR6_K*kqsx;
+
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->ql, kqsx);
+        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
@@ -3003,20 +3168,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd].x = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
-        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI6_K/2);
-
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->qh, k % (QI6_K/2));
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
     }
 
 #pragma unroll
@@ -3043,33 +3195,19 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q6_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
-
-    const int bq8_offset = 2 * QR6_K * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/4);
-    const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8);
-    const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4));
-
-    const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift;
-
-    const int x_sc_offset = i * (WARP_SIZE/8) + i/8 + kbx * (QI6_K/8);
-    const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
-    int    u[QR6_K];
-    float d8[QR6_K];
-
-    for (int l = 0; l < QR6_K; ++l) {
-        const int kqsy = j * (QR6_K*WARP_SIZE) + kbx * (QR6_K*QI6_K) + (bq8_offset + 2*l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[kqsy];
-        d8[l] = y_ds[kqsy / QI8_1].x;
-    }
+    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
 
-    return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales,
-                                  x_dm[i * (WARP_SIZE/QI6_K) + i/QI6_K + kbx].x, d8);
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+    const int index_y = j * (QR6_K*WARP_SIZE)     + QR6_K*k;
+    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
-template <int qk, int qr, int qi, typename block_q_t,
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __global__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -3130,7 +3268,16 @@ static __global__ void mul_mat_q(
             const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
             const int kby = tid_x % blocks_per_tile_y_col;
             const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
-            tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby] = y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+
+            // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+            const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+            half2       * dsi_dst = &tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby];
+            if (need_sum) {
+                *dsi_dst = *dsi_src;
+            } else {
+                float * dfi_dst = (float *) dsi_dst;
+                *dfi_dst = (*dsi_src).x;
+            }
         }
 
         __syncthreads();
@@ -3780,7 +3927,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1>
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3789,7 +3936,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1>
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3798,7 +3945,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1>
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3807,7 +3954,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1>
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3816,7 +3963,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1>
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3873,10 +4020,10 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3891,10 +4038,10 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3909,10 +4056,10 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3927,10 +4074,10 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3945,10 +4092,10 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3963,10 +4110,10 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3981,10 +4128,10 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3999,10 +4146,10 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -4017,10 +4164,10 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -4035,10 +4182,10 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }