]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (llama/7921)
authorJohannes Gäßler <redacted>
Fri, 14 Jun 2024 16:41:49 +0000 (18:41 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores

* try CI fix

* try CI fix

* try CI fix

* fix data race

* rever q2_K precision related changes

src/ggml-cuda.cu
src/ggml-cuda/argsort.cu
src/ggml-cuda/common.cuh
src/ggml-cuda/mmq.cuh
src/ggml-cuda/softmax.cu
src/ggml-cuda/vecdotq.cuh

index 64d3b6747fc41e4cb9bbbf4af8941e06854fb9e0..593fa4cdaa51431b4c2773a7686c613ffd5dbdaf 100644 (file)
@@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
+        info.devices[id].nsm   = prop.multiProcessorCount;
+        info.devices[id].smpb  = prop.sharedMemPerBlock;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].smpbo = prop.sharedMemPerBlock;
         info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
 #else
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        info.devices[id].smpb = prop.sharedMemPerBlock;
-        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {
index 1641440617779e9da3de304e678f475bd569675a..15757ca18e4d7a9a390d62eac62cdabb0acca66d 100644 (file)
@@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     const dim3 block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
 
     if (order == GGML_SORT_ORDER_ASC) {
index 7f4764d60e854a0c2b63ed28f8059820b290bccc..de7c2e4349ede23b6365e3abfee4f2b75d9dd1e8 100644 (file)
@@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #define FP16_AVAILABLE
 #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
@@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
         int     cc;                 // compute capability
         int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
+        size_t  smpbo;              // max. shared memory per block (with opt-in)
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;
index 01e2086b41646936a836ffdbc1645b2f7132ba0d..6d57974fb4e7c45c6997a9a511bfd5d8d6af8f48 100644 (file)
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
 typedef void (*load_tiles_mmq_t)(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
 typedef void (*vec_dot_mmq_t)(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
 
@@ -25,9 +25,8 @@ static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected b
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
 
 struct tile_x_sizes {
-    int ql;
+    int qs;
     int dm;
-    int qh;
     int sc;
 };
 
@@ -67,16 +66,16 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
-#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0,                           0}
-#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0,                           0}
-#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0,                           0}
-#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0,                           0}
-#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0,                           0}
-#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0,                           mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
+#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
+#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
+#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
+#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
 
 #define GET_TILE_X_SIZES_BODY                           \
     return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
@@ -89,7 +88,7 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
         type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
         type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
         type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
-        tile_x_sizes{0, 0, 0, 0}
+        tile_x_sizes{0, 0, 0}
 
 static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
     GET_TILE_X_SIZES_BODY;
@@ -103,9 +102,9 @@ static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type)
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
@@ -122,7 +121,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -144,10 +143,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const float * x_df = (const float *) x_dm;
     const int   * y_qs = (const int   *) y + 4;
@@ -172,7 +170,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
@@ -180,10 +178,10 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -205,7 +203,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
         const int k     = k0 + mma_A::get_k(l) % QI4_0;
         const int shift =   4*(mma_A::get_k(l) / QI4_0);
 
-        A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+        A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
     }
 #pragma unroll
     for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -240,12 +238,16 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
             sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -260,7 +262,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -282,10 +284,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
@@ -309,7 +310,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
+                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
@@ -317,10 +318,10 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -341,7 +342,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
         const int k     = k0 + mma_A::get_k(l) % QI4_0;
         const int shift =   4*(mma_A::get_k(l) / QI4_0);
 
-        A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+        A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
     }
 #pragma unroll
     for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -377,12 +378,16 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
             sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -407,7 +412,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
         qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
 
-        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
 
         int qs1 = (ql >>  4)   & 0x0F0F0F0F;
         qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
@@ -416,7 +421,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -439,10 +444,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const float * x_dmf = (const float *) x_dm;
     const int   * y_qs  = (const int   *) y + 4;
@@ -468,17 +472,17 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -499,7 +503,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
         const int i     =    i0 + mma_A::get_i(l);
         const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
 
-        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
     }
 #pragma unroll
     for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -534,12 +538,16 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
             sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -563,7 +571,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
         qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
 
-        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
 
         int qs1 = (ql >>  4) & 0x0F0F0F0F;
         qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
@@ -571,7 +579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -593,10 +601,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int   * y_qs  = (const int   *) y + 4;
     const half2 * y_ds  = (const half2 *) y;
@@ -621,17 +628,17 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-                (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -651,7 +658,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
         const int i     =    i0 + mma_A::get_i(l);
         const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
 
-        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
     }
 #pragma unroll
     for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -687,13 +694,16 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
             sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
@@ -709,7 +719,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -731,10 +741,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+    GGML_UNUSED(x_sc);
 
     const float * x_dmf = (const float *) x_dm;
     const int   * y_qs  = (const int   *) y + 4;
@@ -749,7 +758,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int i = i0 + threadIdx.x;
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
                 y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
         }
     }
@@ -757,10 +766,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
+    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -781,7 +790,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         const int i = i0 + mma_A::get_i(l);
         const int k = k0 + mma_A::get_k(l);
 
-        A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+        A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
     }
 #pragma unroll
     for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -816,12 +825,15 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh);
 
     const int kbx  = threadIdx.x / QI2_K;
     const int kqsx = threadIdx.x % QI2_K;
@@ -836,48 +848,42 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
-        int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+        for (int l = 0; l < QR2_K; ++l) {
+            const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
 
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
-    }
+            int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
+            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
+            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
 
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+            if (kqsx % QR2_K != 0) {
+                continue;
+            }
 
-        if (need_check) {
-            i = min(i, i_max);
+            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
         }
 
-        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
+        const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+        const float2 bxi_dmf = __half22float2(bxi->dm);
+        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
 
-        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
+        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    GGML_UNUSED(x_qh);
-
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -887,30 +893,99 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const int kbx = k0 / QI2_K;
-            const int ky  = (k0 % QI2_K) * QR2_K;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
+                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
+                &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
-            int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
 
-            const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
-            const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+    mma_A   A[2];
+    float  dA[mma_C::ne/2][2];
+    float  mA[mma_C::ne/2][2];
 
 #pragma unroll
-            for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
-                v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
-            }
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int shift = 2*mma_A::get_k(l);
 
-            const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+        A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
+        A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
+    }
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
-                v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
-                x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+#pragma unroll
+        for (int kk = 0; kk < 2; ++kk) {
+            const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
+
+            dA[l][kk] = dm.x;
+            mA[l][kk] = dm.y;
         }
     }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C Cd[2];
+        mma_C Cm[2];
+        mma_B B[2];
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        Cd[0].mma_K4(A[0], B[0]);
+        Cd[1].mma_K4(A[1], B[1]);
+
+        mma_A A1;
+        A1.x[0] = 0x01010101;
+        A1.x[1] = 0x01010101;
+        Cm[0].mma_K4(A1, B[0]);
+        Cm[1].mma_K4(A1, B[1]);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
+        }
+    }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
     const int kbx  = threadIdx.x / QI3_K;
@@ -926,7 +1001,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        const int x_ql_0 = get_int_from_uint8(bxi->qs,    kqsx);
+        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+        for (int l = 0; l < QR3_K; ++l) {
+            const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;
+            const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;
+
+            int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
+            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
+
+            if (kqsx % 2 != 0) {
+                continue;
+            }
+
+            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+        }
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
@@ -946,20 +1039,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
     }
 
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
-        int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
-
-        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-        x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
-    }
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
         int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
@@ -987,13 +1066,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1008,31 +1087,102 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
 
             const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
-            int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
+                x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A   A[2];
+    int   scA[mma_C::ne/2][2];
+    float  dA[mma_C::ne/2];
 
 #pragma unroll
-            for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
-                const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
-                const int shift = 2 * ((ky % 32) / 8);
-                const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = QR3_K*k0 + mma_A::get_k(l);
 
-                const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
-                const int vlh = (vh << 2) & 0x04040404;
+        A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
+        A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
+        A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
+        A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
+    }
 
-                v[l] = __vsubss4(vll, vlh);
-            }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
-                v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
-                x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+        const int kbx  = k0 / QI3_K;
+        const int ky  = (k0 % QI3_K) * QR3_K;
+        const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+        scA[l][0] = sc[0];
+        scA[l][1] = sc[1];
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C[2];
+        mma_B B[2];
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C[0].mma_K4(A[0], B[0]);
+        C[1].mma_K4(A[1], B[1]);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh);
 
     const int kbx  = 0;           // threadIdx.x / QI4_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
@@ -1047,7 +1197,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
@@ -1090,11 +1240,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    GGML_UNUSED(x_qh);
-
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -1109,7 +1257,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
             const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
-                &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
+                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
                 x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
@@ -1117,10 +1265,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -1143,7 +1290,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
             const int i = i0 + mma_A::get_i(l);
             const int k = k0 + mma_A::get_k(l);
 
-            A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+            A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
         }
 
 #pragma unroll
@@ -1204,12 +1351,15 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
             sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh);
 
     const int kbx  = 0;           // threadIdx.x / QI5_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
@@ -1236,8 +1386,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
 
-        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
-        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+        x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
@@ -1280,11 +1430,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    GGML_UNUSED(x_qh);
-
     const int   * y_qs  = (const int   *) y + 4;
     const half2 * y_ds  = (const half2 *) y;
 
@@ -1299,7 +1447,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
             const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
-                &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
+                &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
                 x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
@@ -1307,10 +1455,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -1333,7 +1480,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
             const int i = i0 + mma_A::get_i(l);
             const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
 
-            A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
+            A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
         }
 
 #pragma unroll
@@ -1394,12 +1541,15 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
             sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_qh);
 
     const int kbx  = 0;           // threadIdx.x / QI6_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
@@ -1426,8 +1576,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
 
-        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
-        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+        x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
@@ -1463,11 +1613,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    GGML_UNUSED(x_qh);
-
     const float * x_dmf = (const float *) x_dm;
     const int   * y_qs  = (const int   *) y + 4;
     const float * y_df  = (const float *) y;
@@ -1483,7 +1631,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
             const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
-                &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
+                &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
                 x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
@@ -1491,10 +1639,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+#ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
@@ -1505,7 +1652,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const float * y_df = (const float *) y;
 
     const int i0 = threadIdx.y*mma_A::I;
+#ifdef INT8_MMA_AVAILABLE
     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
 
     mma_A   A[4];
     int   scA[mma_C::ne/2][4];
@@ -1517,8 +1666,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int i = i0 + mma_A::get_i(l);
             const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
 
-            A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
-            A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
+            A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
+            A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
         }
 
 #pragma unroll
@@ -1578,6 +1727,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
         }
     }
+#else
+    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1608,7 +1761,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
     typedef mma_int_C_I16J8 mma_C;
 
     const int i0 = threadIdx.y*mma_C::I;
+#ifdef INT8_MMA_AVAILABLE
     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
@@ -1638,125 +1793,85 @@ struct mmq_type_traits;
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
-    static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
-    static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
-    static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
-    static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
-    static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
-    static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
-    static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
-    static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
-    static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
-    static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-#ifdef INT8_MMA_AVAILABLE
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
-    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
-static int mmq_need_sum(const ggml_type type_x) {
+static bool mmq_need_sum(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -1790,7 +1905,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if __CUDA_ARCH__ >= CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
 #else
-    __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
@@ -1809,16 +1924,21 @@ static __global__ void mul_mat_q(
     constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
-    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
-    constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
+
+#ifdef INT8_MMA_AVAILABLE
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 
     constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
 
     extern __shared__ char data_mul_mat_q[];
-    int   * tile_x_ql = (int   *)  data_mul_mat_q;
-    half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
-    int   * tile_x_qh = (int   *) (tile_x_dm + txs.dm);
-    int   * tile_x_sc = (int   *) (tile_x_qh + txs.qh);
+    int   * tile_x_qs = (int   *)  data_mul_mat_q;
+    half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
+    int   * tile_x_sc = (int   *) (tile_x_dm + txs.dm);
     int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
 
     const int blocks_per_row_x = ne00 / qk;
@@ -1834,7 +1954,7 @@ static __global__ void mul_mat_q(
 
     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
-        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
 
 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1850,7 +1970,7 @@ static __global__ void mul_mat_q(
 
 // #pragma unroll // unrolling this loop causes too much register pressure
             for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
-                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
+                vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
             }
 
             __syncthreads();
@@ -1867,6 +1987,19 @@ struct mmq_args {
     int64_t ne0;
 };
 
+constexpr int mmq_get_nwarps(int mmq_x) {
+    return mmq_x >= 32 ? 8 : 4;
+}
+
+static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
+    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
+    const int nwarps = mmq_get_nwarps(mmq_x);
+
+    const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+}
+
 template <ggml_type type, int mmq_x, int nwarps>
 static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
@@ -1878,10 +2011,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
-    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
-    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+    const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1905,9 +2035,10 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
 
 template <ggml_type type>
 void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
-    const int id = ggml_cuda_get_device();
-    const int nsm = ggml_cuda_info().devices[id].nsm;
-    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int id    = ggml_cuda_get_device();
+    const int nsm   = ggml_cuda_info().devices[id].nsm;
+    const int cc    = ggml_cuda_info().devices[id].cc;
+    const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
@@ -1920,7 +2051,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
         const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
         const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
 
-        if (nwaves < nwaves_best) {
+        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
             mmq_x_best  = mmq_x;
             nwaves_best = nwaves;
         }
@@ -1928,54 +2059,55 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
 
     switch (mmq_x_best) {
         case   8:
-            launch_mul_mat_q<type,   8, 4>(args, stream);
+            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
             break;
         case  16:
-            launch_mul_mat_q<type,  16, 4>(args, stream);
+            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
             break;
         case  24:
-            launch_mul_mat_q<type,  24, 4>(args, stream);
+            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
             break;
         case  32:
-            launch_mul_mat_q<type,  32, 8>(args, stream);
+            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
             break;
         case  40:
-            launch_mul_mat_q<type,  40, 8>(args, stream);
+            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
             break;
         case  48:
-            launch_mul_mat_q<type,  48, 8>(args, stream);
+            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
             break;
         case  56:
-            launch_mul_mat_q<type,  56, 8>(args, stream);
+            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
             break;
         case  64:
-            launch_mul_mat_q<type,  64, 8>(args, stream);
+            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
             break;
         case  72:
-            launch_mul_mat_q<type,  72, 8>(args, stream);
+            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
             break;
         case  80:
-            launch_mul_mat_q<type,  80, 8>(args, stream);
+            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
             break;
         case  88:
-            launch_mul_mat_q<type,  88, 8>(args, stream);
+            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
             break;
         case  96:
-            launch_mul_mat_q<type,  96, 8>(args, stream);
+            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
             break;
         case 104:
-            launch_mul_mat_q<type, 104, 8>(args, stream);
+            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
             break;
         case 112:
-            launch_mul_mat_q<type, 112, 8>(args, stream);
+            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
             break;
         case 120:
-            launch_mul_mat_q<type, 120, 8>(args, stream);
+            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
             break;
         case 128:
-            launch_mul_mat_q<type, 128, 8>(args, stream);
+            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
             break;
         default:
+            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
             GGML_ASSERT(false);
             break;
     }
index ce64f2f2ce28b1095f7497a96471107c5d4a87c4..c24abae1f138c6fd6d5999f5f1d09139fd0b4eba 100644 (file)
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
index b9573a7c7d053d4e5707d091f2898bc414bd937d..3b12d656616be491cc88cb4e7c8652f25a1ba696 100644 (file)
@@ -265,36 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
 // contiguous u/y values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const half2 & dm2, const float & d8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    int sumi_d = 0;
-    int sumi_m = 0;
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
 
 #pragma unroll
     for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        int sumi_d_sc = 0;
-
-        const int sc = scales[i0 / (QI8_1/2)];
-
-        // fill int with 4x m
-        int m = sc >> 4;
-        m |= m <<  8;
-        m |= m << 16;
+        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
+        int sumi_d = 0;
+        int sumi_m = 0;
 
+        const int vi0 = v[i0/(QI8_1/2)];
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
-            sumi_m    = __dp4a(m,    u[i], sumi_m); // multiply sum of q8_1 values with m
+            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
+            sumi_d = __dp4a(vi,         u[i], sumi_d); // SIMD dot product
+            sumi_m = __dp4a(0x01010101, u[i], sumi_m);
         }
 
-        sumi_d += sumi_d_sc * (sc & 0xF);
+        sumf_d += dm2f.x * sumi_d;
+        sumf_m += dm2f.y * sumi_m;
     }
 
-    const float2 dm2f = __half22float2(dm2);
-
-    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+    return d8*(sumf_d - sumf_m);
 #else
     NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -352,8 +347,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
         int sumi_sc = 0;
 
+#pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
+            sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product
         }
 
         sumi += sumi_sc * scales[i0 / (QI8_1/2)];