]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: optimize MMQ int8 tensor core performance (#8062)
authorJohannes Gäßler <redacted>
Mon, 24 Jun 2024 10:41:23 +0000 (12:41 +0200)
committerGitHub <redacted>
Mon, 24 Jun 2024 10:41:23 +0000 (12:41 +0200)
* CUDA: optimize MMQ int8 tensor core performance

* only a single get_mma_tile_x_k function

* simplify code, make functions constexpr

ggml-cuda/common.cuh
ggml-cuda/mma.cuh
ggml-cuda/mmq.cuh

index 5bd24ebe5fa79d83f075754e910f3372c030385a..5c866253586e7a99322b2b300b51cf566025c91b 100644 (file)
@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
-static int get_mmq_x_max_host(const int cc) {
+static constexpr int get_mmq_x_max_host(int cc) {
 #ifdef CUDA_USE_TENSOR_CORES
     return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
 #else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc) {
+static constexpr int get_mmq_y_host(int cc) {
     return cc >= CC_VOLTA ? 128 : 64;
 }
 
index 63e07fbc21291013fac4bab8d188fb8bb6c58a5d..0301a52f9ea719ad651a9a9e2ebd1a6b5e4895fe 100644 (file)
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_C_I16J8 {
index e2d07c20253ae3529ea24c71a61ca095cb68edfc..0f7f8ae51e425c65c69b88395bff5f868e53597d 100644 (file)
@@ -7,15 +7,8 @@
 #include <climits>
 #include <cstdint>
 
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
-#define MMQ_NWARPS 8
-
-typedef void (*load_tiles_mmq_t)(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
 struct block_q8_1_mmq {
@@ -63,51 +56,101 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
-#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
-#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
-#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
-#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
-#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-
-#define GET_TILE_X_SIZES_BODY                           \
-    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
-        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 :    \
-        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 :    \
-        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 :    \
-        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 :    \
-        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K :    \
-        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K :    \
-        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
-        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
-        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
-        tile_x_sizes{0, 0, 0}
-
-static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
-    GET_TILE_X_SIZES_BODY;
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
+#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        tile_x_sizes{0, 0, 0};
 }
 
-template <int mmq_y>
-static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
-    GET_TILE_X_SIZES_BODY;
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0               + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1               + 4)
+#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0               + 4)
+#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1               + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0               + 0)
+#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE                     + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        0;
+}
+
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
+
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 48 ? 16 : 8;
 }
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+    return 8;
+}
+#endif // INT8_MMA_AVAILABLE
 
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
 
-    float * x_dmf = (float *) x_dm;
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + threadIdx.y;
@@ -118,7 +161,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -134,17 +181,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q4_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_df = (const float *) x_dm;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -175,76 +226,90 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     = i0 + mma_A::get_i(l);
-        const int k     = k0 + mma_A::get_k(l) % QI4_0;
-        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int k     = k0              + mma_A::get_k(l) % QI4_0;
+            const int shift =                4*(mma_A::get_k(l) / QI4_0);
+
+            A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+        }
 
-        A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
-        mma_B B;
-        half2 dsB[mma_C::ne/2];
-
 #pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        mma_B  B;
+        float dB[mma_C::ne/2];
+
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
 
-            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+            dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -259,7 +324,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -275,16 +344,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q4_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -315,51 +389,53 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_A_I16K4 mma_A_K4;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    half2 dmA[mma_C::ne/2];
+    mma_A A[ntx];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     = i0 + mma_A::get_i(l);
-        const int k     = k0 + mma_A::get_k(l) % QI4_0;
-        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+    for (int n = 0; n < ntx; ++n) {
+        ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
+        A[n].x[2]  = (A[n].x[0] >> 4) & 0x0F0F0F0F;
+        A[n].x[3]  = (A[n].x[1] >> 4) & 0x0F0F0F0F;
+        A[n].x[0] &= 0x0F0F0F0F;
+        A[n].x[1] &= 0x0F0F0F0F;
 
-        A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -367,24 +443,35 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
-            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -409,8 +496,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
         qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
         int qs1 = (ql >>  4)   & 0x0F0F0F0F;
         qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
         qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
@@ -418,12 +503,17 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
@@ -435,19 +525,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -457,70 +551,57 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
-
-            int u[2*VDR_Q5_0_Q8_1_MMQ];
-
-#pragma unroll
-            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
-            }
-
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+                 x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     =    i0 + mma_A::get_i(l);
-        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
 
-        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
 
-        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -528,23 +609,34 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -568,15 +660,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
         qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
         int qs1 = (ql >>  4) & 0x0F0F0F0F;
         qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
         qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -592,18 +688,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q5_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -613,69 +714,57 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
-
-            int u[2*VDR_Q5_1_Q8_1_MMQ];
-
-#pragma unroll
-            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
-            }
-
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+                 x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    half2 dmA[mma_C::ne/2];
+    mma_A A[ntx];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     =    i0 + mma_A::get_i(l);
-        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
 
-        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -683,28 +772,38 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
-            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -716,7 +815,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -732,19 +835,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0         + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -755,7 +862,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int i = i0 + threadIdx.x;
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
                 y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
         }
     }
@@ -763,51 +870,48 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int k = k0 + mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
 
-        A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = k0 + mma_B::get_k(l);
+        B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -815,22 +919,34 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI2_K;
     const int kqsx = threadIdx.x % QI2_K;
@@ -859,7 +975,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
                 continue;
             }
 
-            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+            x_qs[i*(WARP_SIZE + 1)       + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -870,15 +990,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
+#else
+        x_dm[i*(WARP_SIZE + 1)       + threadIdx.x] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -899,61 +1025,63 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[2];
-    float  dA[mma_C::ne/2][2];
-    float  mA[mma_C::ne/2][2];
+    mma_A   A[ntx][2];
+    float  dA[ntx][mma_C::ne/2][2];
+    float  mA[ntx][mma_C::ne/2][2];
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int shift = 2*mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int shift = 2*mma_A::get_k(l);
 
-        A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
-        A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
-    }
+            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
+            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
 #pragma unroll
-        for (int kk = 0; kk < 2; ++kk) {
-            const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
+            for (int kdm = 0; kdm < 2; ++kdm) {
+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
 
-            dA[l][kk] = dm.x;
-            mA[l][kk] = dm.y;
+                dA[n][l][kdm] = dm.x;
+                mA[n][l][kdm] = dm.y;
+            }
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C Cd[2];
-        mma_C Cm[2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B[2];
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
+        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -961,9 +1089,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        Cd[0].mma_K4(A[0], B[0]);
-        Cd[1].mma_K4(A[1], B[1]);
-
+        mma_C Cm[2];
         mma_A A1;
         A1.x[0] = 0x01010101;
         A1.x[1] = 0x01010101;
@@ -971,19 +1097,38 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         Cm[1].mma_K4(A1, B[1]);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
+    for (int n = 0; n < ntx; ++n) {
+            mma_C Cd[2];
+
+            Cd[0].mma_K4(A[n][0], B[0]);
+            Cd[1].mma_K4(A[n][1], B[1]);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
+                    Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI3_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI3_K;
     const int kqsx = threadIdx.x % QI3_K;
@@ -1015,13 +1160,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
                 continue;
             }
 
-            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k/2] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
         }
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
@@ -1033,7 +1181,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1058,16 +1210,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
+#else
+        x_sc[i*(WARP_SIZE/4) + i/4   + threadIdx.x % (WARP_SIZE/4)] = sc;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_df = (const float *) x_dm;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -1093,69 +1251,72 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI3_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    float  dA[mma_C::ne/2];
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    float  dA[ntx][mma_C::ne/2];
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int k = QR3_K*k0 + mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int k = QR3_K*k0 + mma_A::get_k(l);
 
-        A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
-        A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
-        A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
-        A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
-    }
+            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
+            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
+            A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
+            A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        const int kbx  = k0 / QI3_K;
-        const int ky  = (k0 % QI3_K) * QR3_K;
-        const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+            const int kbx  = k0 / QI3_K;
+            const int ky  = (k0 % QI3_K) * QR3_K;
+            const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
 
-        scA[l][0] = sc[0];
-        scA[l][1] = sc[1];
-    }
+            scA[n][l][0] = sc[0];
+            scA[n][l][1] = sc[1];
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C[2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B[2];
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
+        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -1163,23 +1324,37 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C[0].mma_K4(A[0], B[0]);
-        C[1].mma_K4(A[1], B[1]);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C[2];
+            C[0].mma_K4(A[n][0], B[0]);
+            C[1].mma_K4(A[n][1], B[1]);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI4_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI4_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
@@ -1194,7 +1369,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
@@ -1210,7 +1389,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q4_K       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1231,15 +1414,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
         scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -1262,71 +1452,79 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
+    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI4_K;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    int    mA[ntx][mma_C::ne/2][2];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    int    mA[mma_C::ne/2][2];
-    half2 dmA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = k0 + mma_A::get_k(l);
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
+            A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
 
-            A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+#pragma unroll
+            for (int l = 0; l < mma_A::ne; ++l) {
+                A[n][kvdr/4 + 1].x[l]  = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
+                A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
+            }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
-            const uint8_t *  m = sc + 8;
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
+                const uint8_t *  m = sc + 8;
 
-            scA[l][kvdr/4] = sc[kvdr/4];
-            mA[l][kvdr/4]  =  m[kvdr/4];
+                scA[n][l][kvdr/4] = sc[kvdr/4];
+                mA[n][l][kvdr/4]  =  m[kvdr/4];
+            }
         }
-    }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmpd[mma_C::ne] = {0.0f};
-        float tmpm[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmpd[ntx][mma_C::ne] = {{0.0f}};
+        float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C   C;
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
 
-                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1334,29 +1532,46 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
                 dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][kvdr/4], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI5_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI5_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
@@ -1383,8 +1598,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
 
-        x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
-        x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
@@ -1400,7 +1620,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q5_K       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1421,17 +1645,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
         scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1452,71 +1683,70 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI5_K;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    int    mA[ntx][mma_C::ne/2][2];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    int    mA[mma_C::ne/2][2];
-    half2 dmA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
-
-            A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
-        }
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
-            const uint8_t *  m = sc + 8;
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
+                const uint8_t *  m = sc + 8;
 
-            scA[l][kvdr/4] = sc[kvdr/4];
-            mA[l][kvdr/4]  =  m[kvdr/4];
+                scA[n][l][kvdr/4] = sc[kvdr/4];
+                mA[n][l][kvdr/4]  =  m[kvdr/4];
+            }
         }
-    }
 
-#pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+    #pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmpd[mma_C::ne] = {0.0f};
-        float tmpm[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmpd[ntx][mma_C::ne] = {{0.0f}};
+        float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C   C;
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
 
-                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1524,29 +1754,46 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
                 dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][kvdr/4], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI6_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
@@ -1573,13 +1820,17 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
 
-        x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
-        x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
     const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
@@ -1591,7 +1842,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1604,18 +1859,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1629,80 +1890,77 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
                 &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
-                x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
+                x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI6_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-#ifdef INT8_MMA_AVAILABLE
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    float  dA[ntx][mma_C::ne/2];
 
-    mma_A   A[4];
-    int   scA[mma_C::ne/2][4];
-    float  dA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
-
-            A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
-            A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
-        }
+        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+            A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+                const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
 
-            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
-            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+                scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+                scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+            }
         }
-    }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmp[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmp[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C C[2];
             mma_B B[2];
             float dB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k0B, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
 
-                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1710,22 +1968,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
                 dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
-            C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
+                C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
@@ -1761,28 +2026,37 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
     typedef mma_int_C_I16J8 mma_C;
 
-    const int i0 = threadIdx.y*mma_C::I;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
 #ifdef INT8_MMA_AVAILABLE
     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
 #endif // INT8_MMA_AVAILABLE
 
+    dst += (threadIdx.y % ntx) * mma_C::J*stride;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-            if (j > j_max) {
-                continue;
-            }
+                if (j > j_max) {
+                    continue;
+                }
 
-            const int i = i0 + mma_C::get_i(l);
+                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
 
-            if (need_check && i > i_max) {
-                continue;
-            }
+                if (need_check && i > i_max) {
+                    continue;
+                }
 
-            dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+            }
         }
     }
 }
@@ -1910,6 +2184,10 @@ static __device__ void mul_mat_q_process_tile(
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
 
+    extern __shared__ char data_mul_mat_q[];
+    int * tile_y = (int *) data_mul_mat_q;
+    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+
 #ifdef INT8_MMA_AVAILABLE
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
@@ -1918,14 +2196,6 @@ static __device__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 #endif // INT8_MMA_AVAILABLE
 
-    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
-
-    extern __shared__ char data_mul_mat_q[];
-    int   * tile_x_qs = (int   *)  data_mul_mat_q;
-    half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
-    int   * tile_x_sc = (int   *) (tile_x_dm + txs.dm);
-    int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
-
     constexpr int blocks_per_warp = WARP_SIZE / qi;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
@@ -1937,7 +2207,7 @@ static __device__ void mul_mat_q_process_tile(
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
 
-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
 
 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1953,7 +2223,7 @@ static __device__ void mul_mat_q_process_tile(
 
 // #pragma unroll // unrolling this loop causes too much register pressure
             for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
-                vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
+                vec_dot(tile_x, tile_y, sum, k0);
             }
 
             __syncthreads();
@@ -1987,7 +2257,7 @@ static __global__ void mul_mat_q(
     const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
 
     // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
+    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
         NO_DEVICE_CODE;
         return;
     }
@@ -2139,11 +2409,12 @@ struct mmq_args {
     int64_t ne0;
 };
 
-static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
-    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-
-    const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+template<ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
     return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
 
@@ -2156,7 +2427,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
 
-    const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
+    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -2225,12 +2496,17 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     int nparts_best = INT_MAX;
 
     for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+            continue;
+        }
+
         const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
         const int nwaves_xy_tiling = ntiles_x*block_num_y;
-
         const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
 
-        if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+        if (nparts < nparts_best) {
             mmq_x_best  = mmq_x;
             nparts_best = nparts;
         }