]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: optimize and refactor MMQ (llama/8416)
authorJohannes Gäßler <redacted>
Thu, 11 Jul 2024 14:47:47 +0000 (16:47 +0200)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 19:48:46 +0000 (22:48 +0300)
* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation

ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/quantize.cu
ggml/src/ggml-cuda/quantize.cuh
ggml/src/ggml-cuda/vecdotq.cuh

index 5d87dd8e64c0aa2147ca03dd070e60431fffa8cf..a452a3cc3d152f850d95b8534c10b883bd3e03e2 100644 (file)
@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };
 
 struct mma_int_B_J8K4 {
index 118e34d2847fc8a42237a9eaffd8f057ee202932..51c44d857a4908c88fc7579bd40f50313a290f66 100644 (file)
@@ -8,18 +8,70 @@
 #include <cstdint>
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
+enum mmq_q8_1_ds_layout {
+    MMQ_Q8_1_DS_LAYOUT_D4,
+    MMQ_Q8_1_DS_LAYOUT_DS4,
+    MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
 struct block_q8_1_mmq {
-    half2  ds[4];
-    int8_t qs[4*QK8_1];
+    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+    // The y float data is first grouped as blocks of 128 values.
+    // These blocks are then treated as individual data values and transposed.
+    //
+    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // The scales multiplied with the quantized data are equal to the unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    //     and are only needed for performance reasons.
+    //
+    // The exact data stored depends on the x data type.
+    union {
+        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                        //     stored as d0,d1,s1,s2,s3,s4,s5
+    };
+    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
 
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q5_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q5_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q8_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q2_K:
+            return MMQ_Q8_1_DS_LAYOUT_D2S6;
+        case GGML_TYPE_Q3_K:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 struct tile_x_sizes {
     int qs;
     int dm;
@@ -79,49 +131,46 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
-#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K   + mmq_y/QI4_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
         type == GGML_TYPE_Q4_1    ? MMQ_DP4A_TXS_Q4_1 :
-        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q5_0 :
-        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q8_1 :
         type == GGML_TYPE_Q8_0    ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_Q2_K    ? MMQ_DP4A_TXS_Q2_K :
         type == GGML_TYPE_Q3_K    ? MMQ_DP4A_TXS_Q3_K :
         type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
         type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
         type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
-        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q5_0 :
-        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q8_0 :
         tile_x_sizes{0, 0, 0};
 }
 
-#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0               + 4)
-#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1               + 4)
-#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0               + 4)
-#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1               + 4)
-#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0               + 0)
-#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE                     + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
-#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0                   + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1                   + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE                         + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K     + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K     + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K     + WARP_SIZE/8 + 7)
 
 static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
@@ -131,21 +180,20 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
         type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q4_1 :
-        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q5_0 :
-        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
         type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
         type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q4_K :
         type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q5_K :
         type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
-        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q5_0 :
-        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q8_0 :
         0;
 }
 
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
-#define MMQ_NWARPS 8
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
@@ -218,7 +266,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -226,34 +274,39 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
 
-            int u[2*VDR_Q4_0_Q8_1_MMQ];
+                int u[2*VDR_Q4_0_Q8_1_MMQ];
 
 #pragma unroll
-            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
-            }
+                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+                }
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
-                y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -271,52 +324,60 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx];
-    float dA[ntx][mma_C::ne/2];
+    mma_A A[ntx][4];
+    float dA[ntx][mma_C::ne/2][4];
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
-            const int k     = k0              + mma_A::get_k(l) % QI4_0;
-            const int shift =                4*(mma_A::get_k(l) / QI4_0);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+            const int k0 = k00 + k01;
 
-            A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
-        }
+#pragma unroll
+            for (int l = 0; l < mma_A::ne; ++l) {
+                const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
+                const int k     = k0/QR4_0        + mma_A::get_k(l) % QI4_0;
+                const int shift =                4*(mma_A::get_k(l) / QI4_0);
+
+                A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+            }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
+                dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
+            }
         }
     }
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B  B;
-        float dB[mma_C::ne/2];
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
+            mma_B  B;
+            float dB[mma_C::ne/2];
 
-        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-            dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
-        }
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
 
 #pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
+                }
             }
         }
     }
@@ -381,7 +442,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -389,34 +450,39 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
 
-            int u[2*VDR_Q4_1_Q8_1_MMQ];
+                int u[2*VDR_Q4_1_Q8_1_MMQ];
 
 #pragma unroll
-            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
-            }
+                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+                }
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
-                y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -435,50 +501,58 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx];
-    half2 dmA[ntx][mma_C::ne/2];
+    mma_A A[ntx][4];
+    half2 dmA[ntx][mma_C::ne/2][4];
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
-        ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
-        A[n].x[2]  = (A[n].x[0] >> 4) & 0x0F0F0F0F;
-        A[n].x[3]  = (A[n].x[1] >> 4) & 0x0F0F0F0F;
-        A[n].x[0] &= 0x0F0F0F0F;
-        A[n].x[1] &= 0x0F0F0F0F;
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
+            A[n][k01/(QR4_1*QI4_1)].x[2]  = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[3]  = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
+            A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
+                dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
+            }
         }
     }
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B;
-        half2 dsB[mma_C::ne/2];
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
+            mma_B B;
+            half2 dsB[mma_C::ne/2];
 
-        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
-        }
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
 
 #pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                }
             }
         }
     }
@@ -531,8 +605,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
@@ -553,106 +627,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q5_0       + kbxd] = bxi->d;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + txs.qs;
-    const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-
-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
-                 x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
-        }
-    }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
-
-    mma_A A[ntx];
-    float dA[ntx][mma_C::ne/2];
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
-
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B;
-        float dB[mma_C::ne/2];
-
-        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
-
-            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
-        }
-
-#pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -694,8 +675,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
@@ -716,41 +697,101 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q5_1       + kbxd] = bxi->dm;
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
 #endif // INT8_MMA_AVAILABLE
     }
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
+#else
+        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const float * x_df = (const float *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
-                 x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -764,140 +805,106 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+    const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
+    const float * y_df = (const float *) y;
 
-    mma_A A[ntx];
-    half2 dmA[ntx][mma_C::ne/2];
+    mma_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
-        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
-
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            const int k0 = k00 + k01;
 
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
+            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B;
-        half2 dsB[mma_C::ne/2];
-
-        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
-
-            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
-        }
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
 #pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+                const int k0 = k00 + k01;
 
-#pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
             }
         }
     }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            const int k0 = k00 + k01;
 
-#ifdef INT8_MMA_AVAILABLE
-    int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_tile + WARP_SIZE);
-#else
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
-    int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+            mma_B B;
+            float dB[mma_C::ne/2];
 
-    const int kbx  = threadIdx.x / QI8_0;
-    const int kqsx = threadIdx.x % QI8_0;
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-#ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
-#else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
-#endif // INT8_MMA_AVAILABLE
-    }
+                dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_0], B);
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
-        int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                }
+            }
         }
-
-        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
-
-#ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0         + kbxd] = bxi->d;
+    }
 #else
-        x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
-    }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + txs.qs;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
-                y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -911,49 +918,65 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
+    const half2 * y_dm = (const half2 *) y;
 
-    mma_A A[ntx];
-    float dA[ntx][mma_C::ne/2];
+    mma_A   A[ntx][WARP_SIZE/QI8_1];
+    half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
-        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+                const int k0 = k00 + k01;
+
+                dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
+            }
         }
     }
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B;
-        float dB[mma_C::ne/2];
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            mma_B   B;
+            half2 dsB[mma_C::ne/2];
 
-        B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
-        }
+                dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
 
 #pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C;
-            C.mma_K8(A[n], B);
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_1], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                }
             }
         }
     }
@@ -968,44 +991,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = threadIdx.x / QI2_K;
     const int kqsx = threadIdx.x % QI2_K;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
 
         const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
 
 #pragma unroll
         for (int l = 0; l < QR2_K; ++l) {
-            const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
 
-            int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
-            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
-            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
-
-            if (kqsx % QR2_K != 0) {
-                continue;
-            }
+            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
 #ifdef INT8_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
-            x_qs[i*(WARP_SIZE + 1)       + k] = x_qs_k;
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
 #endif // INT8_MMA_AVAILABLE
         }
 
@@ -1018,44 +1034,68 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 #endif // FAST_FP16_AVAILABLE
 
 #ifdef INT8_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
-        x_dm[i*(WARP_SIZE + 1)       + threadIdx.x] = x_dm_ik;
+        x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
 #endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
 
+    float2 y_df[mmq_x/nwarps];
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
+        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+    }
+
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
-                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
-                &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+                if (k01 < WARP_SIZE/2) {
+                    constexpr int ns = 2;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                } else {
+                    constexpr int ns = 1;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                }
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
@@ -1066,74 +1106,107 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[ntx][2];
-    float  dA[ntx][mma_C::ne/2][2];
-    float  mA[ntx][mma_C::ne/2][2];
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+    float  mA[ntx][mma_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
-            const int shift = 2*mma_A::get_k(l);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
 
-            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
-            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
+            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
+    }
 
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
 #pragma unroll
-            for (int kdm = 0; kdm < 2; ++kdm) {
-                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+                const int k0 = k00 + k01;
 
-                dA[n][l][kdm] = dm.x;
-                mA[n][l][kdm] = dm.y;
+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+                dA[n][l][k01/(QI8_1/2)] = dm.x;
+                mA[n][l][k01/(QI8_1/2)] = dm.y;
             }
         }
     }
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B[2];
-        float dB[mma_C::ne/2];
-
-        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
-        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
+        float2 dB[mma_C::ne/2];
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
 
-            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
+            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }
 
-        mma_C Cm[2];
-        mma_A A1;
-        A1.x[0] = 0x01010101;
-        A1.x[1] = 0x01010101;
-        Cm[0].mma_K4(A1, B[0]);
-        Cm[1].mma_K4(A1, B[1]);
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B B[2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+            mma_C Cm[2];
+            if (k01 >= WARP_SIZE * 3/4) {
+                mma_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                Cm[0].mma_K4(A1, B[0]);
+                Cm[1].mma_K4(A1, B[1]);
+            }
 
 #pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-            mma_C Cd[2];
+            for (int n = 0; n < ntx; ++n) {
+                mma_C Cd[2];
 
-            Cd[0].mma_K4(A[n][0], B[0]);
-            Cd[1].mma_K4(A[n][1], B[1]);
+                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
-                    Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+                    if (k01 >= WARP_SIZE * 3/4) {
+                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+                    }
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+            float2 sB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                }
             }
         }
     }
@@ -1149,7 +1222,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
-    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI3_K);
+    int   * x_sc = (int   *) (x_df + 1);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -1157,75 +1230,66 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_sc = (int   *) (x_df + txs.dm);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = threadIdx.x / QI3_K;
     const int kqsx = threadIdx.x % QI3_K;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+        int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
         const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
         const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
 
 #pragma unroll
         for (int l = 0; l < QR3_K; ++l) {
-            const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
 
             const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;
             const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;
 
-            int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
-            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
-
-            if (kqsx % 2 != 0) {
-                continue;
-            }
+            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
 #ifdef INT8_MMA_AVAILABLE
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + k/2] = x_qs_k;
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
 #endif // INT8_MMA_AVAILABLE
         }
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
-        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K       + kbxd] = bxi->d;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
+        x_df[i]                       = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
-        const int ksc = threadIdx.x % (QI3_K/4);
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
 
         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1238,16 +1302,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
+        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
 #else
-        x_sc[i*(WARP_SIZE/4) + i/4   + threadIdx.x % (WARP_SIZE/4)] = sc;
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = sc;
 #endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1256,32 +1320,35 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
-            const int kbx  = k0 / QI3_K;
-            const int ky  = (k0 % QI3_K) * QR3_K;
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+                const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
-                &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
-                x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
@@ -1293,73 +1360,74 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI3_K;
+    const int   * x_sc = (const int   *) x_df + 1;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[ntx][2];
-    int   scA[ntx][mma_C::ne/2][2];
+    mma_A   A[ntx][8];
+    int   scA[ntx][mma_C::ne/2][8];
     float  dA[ntx][mma_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
-            const int k = QR3_K*k0 + mma_A::get_k(l);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
 
-            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
-            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
-            A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
-            A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
+            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            const int kbx  = k0 / QI3_K;
-            const int ky  = (k0 % QI3_K) * QR3_K;
-            const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+                const int k0 = k00 + k01;
 
-            scA[n][l][0] = sc[0];
-            scA[n][l][1] = sc[1];
-        }
+                const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
+                const int8_t * sc = (const int8_t *) &sc_packed;
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
+            }
 
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
         }
     }
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        mma_B B[2];
-        float dB[mma_C::ne/2];
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
 
-        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
-        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
 
-            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
-        }
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
 
 #pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-            mma_C C[2];
-            C[0].mma_K4(A[n][0], B[0]);
-            C[1].mma_K4(A[n][1], B[1]);
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
+                        (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
+                }
             }
         }
     }
@@ -1451,7 +1519,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1460,26 +1528,31 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
+                const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
-                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
-                x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                    &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -1500,35 +1573,40 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[ntx][2];
-    int   scA[ntx][mma_C::ne/2][2];
-    int    mA[ntx][mma_C::ne/2][2];
+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    int    mA[ntx][mma_C::ne/2][4];
     half2 dmA[ntx][mma_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
-            A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
 
 #pragma unroll
             for (int l = 0; l < mma_A::ne; ++l) {
-                A[n][kvdr/4 + 1].x[l]  = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
-                A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
+                A[n][k01/8 + 1].x[l]  = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
+                A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
             }
         }
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
+            const int  m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
 
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
-                const uint8_t *  m = sc + 8;
+            const uint8_t * sc = (const uint8_t *) &sc_packed;
+            const uint8_t *  m = (const uint8_t *)  &m_packed;
 
-                scA[n][l][kvdr/4] = sc[kvdr/4];
-                mA[n][l][kvdr/4]  =  m[kvdr/4];
+#pragma unroll
+            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                scA[n][l][ksc] = sc[ksc];
+                mA[n][l][ksc]  =  m[ksc];
             }
         }
 
@@ -1536,7 +1614,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
         }
     }
 
@@ -1546,28 +1624,28 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
         float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
 
-                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][kvdr/4], B);
+                C.mma_K8(A[n][k01/8], B);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][k01/8]           * __high2float(dsB[l%2]);
                 }
             }
         }
@@ -1682,7 +1760,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1691,26 +1769,31 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
-                &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
-                x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
@@ -1731,26 +1814,34 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[ntx][2];
-    int   scA[ntx][mma_C::ne/2][2];
-    int    mA[ntx][mma_C::ne/2][2];
+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    int    mA[ntx][mma_C::ne/2][4];
     half2 dmA[ntx][mma_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
-            A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
+        }
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
+            const int  m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
 
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
-                const uint8_t *  m = sc + 8;
+            const uint8_t * sc = (const uint8_t *) &sc_packed;
+            const uint8_t *  m = (const uint8_t *)  &m_packed;
 
-                scA[n][l][kvdr/4] = sc[kvdr/4];
-                mA[n][l][kvdr/4]  =  m[kvdr/4];
+#pragma unroll
+            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                scA[n][l][ksc] = sc[ksc];
+                mA[n][l][ksc]  =  m[ksc];
             }
         }
 
@@ -1758,7 +1849,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
         }
     }
 
@@ -1768,28 +1859,30 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
         float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
 
-                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][kvdr/4], B);
+                C.mma_K8(A[n][k01/8], B);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][k01/8]           * __high2float(dsB[l%2]);
                 }
             }
         }
@@ -1896,7 +1989,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1905,26 +1998,31 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
 
-            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
-                &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
-                x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
@@ -1945,25 +2043,35 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[ntx][4];
-    int   scA[ntx][mma_C::ne/2][4];
+    mma_A   A[ntx][8];
+    int   scA[ntx][mma_C::ne/2][8];
     float  dA[ntx][mma_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
-            A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0),        MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-                const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
+                const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+                const int8_t * sc        = (const int8_t *) &sc_packed;
 
-                scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
-                scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+#pragma unroll
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
             }
         }
 
@@ -1971,7 +2079,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
 
@@ -1980,30 +2088,29 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         float tmp[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             mma_B B[2];
             float dB[mma_C::ne/2];
 
-            const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k0B, MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
 
-                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C[2];
-                C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
-                C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
         }
@@ -2051,8 +2158,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
@@ -2073,7 +2180,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
 #endif // INT8_MMA_AVAILABLE
@@ -2109,8 +2216,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
@@ -2133,7 +2240,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
 #endif // INT8_MMA_AVAILABLE
@@ -2229,16 +2336,16 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
     static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
     static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -2293,45 +2400,18 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
     static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
-static bool mmq_need_sum(const ggml_type type_x) {
-    switch (type_x) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return true;
-        case GGML_TYPE_Q5_0:
-            return false;
-        case GGML_TYPE_Q5_1:
-            return true;
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-            return false;
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-            return true;
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
-            return false;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-    return false;
-}
-
 template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 static __device__ void mul_mat_q_process_tile(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2339,10 +2419,7 @@ static __device__ void mul_mat_q_process_tile(
     const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
 
     constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
-    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
-    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
     constexpr int              mmq_y      = get_mmq_y_device();
-    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
 
     extern __shared__ char data_mul_mat_q[];
@@ -2357,7 +2434,7 @@ static __device__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 #endif // INT8_MMA_AVAILABLE
 
-    constexpr int blocks_per_warp = WARP_SIZE / qi;
+    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
@@ -2366,29 +2443,40 @@ static __device__ void mul_mat_q_process_tile(
 
     const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
-    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
-
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
 
-#pragma unroll
-        for (int kr = 0; kr < qr; ++kr) {
-            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
             for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                 int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
                 tile_y[l] = by0[l];
             }
+        }
 
-            __syncthreads();
+        __syncthreads();
 
-// #pragma unroll // unrolling this loop causes too much register pressure
-            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
-                vec_dot(tile_x, tile_y, sum, k0);
-            }
+        vec_dot(tile_x, tile_y, sum, 0);
+
+        __syncthreads();
+
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
-            __syncthreads();
+                tile_y[l] = by0[l];
+            }
         }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+
+        __syncthreads();
     }
 
     if (fixup) {
@@ -2424,7 +2512,6 @@ static __global__ void mul_mat_q(
     }
 
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
-    constexpr int qi    = ggml_cuda_type_traits<type>::qi;
     constexpr int mmq_y = get_mmq_y_device();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
@@ -2439,7 +2526,7 @@ static __global__ void mul_mat_q(
 #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
 
     const     int64_t blocks_per_ne00 = ne00 / qk;
-    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
 
     const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
     const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
@@ -2448,8 +2535,8 @@ static __global__ void mul_mat_q(
     int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
     int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
 
-    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_warp;
-    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
+    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
+    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
 
     // kb0 == k index when doing the matrix multiplication for an output tile.
     int kb0_start = kbc % blocks_per_ne00;
@@ -2490,8 +2577,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
-    constexpr int     qi              = ggml_cuda_type_traits<type>::qi;
-    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
     const     int64_t blocks_per_ne00 = ne00 / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
@@ -2501,15 +2587,18 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     bool any_fixup = false;
 
-    const int bidx_start = (blockIdx.y*nty + blockIdx.x)     * block_num_mmq / (gridDim.y*gridDim.x);
-    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+    const int bidx_start = ((blockIdx.y*nty + blockIdx.x)     * block_num_mmq)                           / (gridDim.y*gridDim.x);
+    const int bidx_stop  = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+
+    int64_t kbc_0;
+    int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
 
     for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
-        int64_t kbc      = (int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq;
-        int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+        kbc_0 = kbc_stop_0;
+        kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
 
-        kbc      -= (kbc      % blocks_per_ne00) % blocks_per_warp;
-        kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
+        const int64_t kbc      = kbc_0      - (kbc_0      % blocks_per_ne00) % blocks_per_iter;
+        const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
 
         // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
         if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
index b4678682238d31c1007071b30cd1f79f21dba398..aa7f1eff0e6a25d1e3d8b9c1a892fcd7f916db57 100644 (file)
@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-template <bool need_sum>
+template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
 
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
     if (ix0 >= kx0_padded) {
         return;
     }
 
+    const float4 * x4 = (const float4 *) x;
+
     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
-    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
-
-    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
-    float amax = fabsf(xi);
-
-    amax = warp_reduce_max(amax);
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
 
     float sum;
-    if (need_sum) {
-        sum = warp_reduce_sum(xi);
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
     }
 
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
 
-    y[ib].qs[iqs] = q;
+    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
 
-    if (iqs % QK8_1 != 0) {
         return;
     }
 
-    if (need_sum) {
-        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
     } else {
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
+        y[ib].d4[iqs/32]  = d;
     }
 }
 
@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
 
     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
     const dim3 num_blocks(block_num_x, kx1, channels);
-    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    if (mmq_need_sum(type_x)) {
-        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else {
-        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
     }
 }
index 486c9360a46fdc782614e052e2524cf12a27eb7a..03bf322b95873613a1b7dedd6812da1c9eb84868 100644 (file)
@@ -5,7 +5,11 @@
 
 #include <cstdint>
 
-#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
index 1d510484ac4742931fdb04daaabd2bfbea8a25c5..6a17d0f3eb0d29155f8f524968c13b4858c82e74 100644 (file)
@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 }
 
 #define VDR_Q2_K_Q8_1_MMVQ 1
-#define VDR_Q2_K_Q8_1_MMQ  2
+#define VDR_Q2_K_Q8_1_MMQ  4
 
 // contiguous v/x values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
+template <int ns8>
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
 
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
+    float sumf    = 0.0f;
+    float sumf_d8 = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
-        int sumi_d = 0;
-        int sumi_m = 0;
+    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+        int sumi_d0 = 0;
+
+        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+        int sumi_d1 = 0;
 
-        const int vi0 = v[i0/(QI8_1/2)];
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
-            sumi_d = ggml_cuda_dp4a(vi,         u[i], sumi_d); // SIMD dot product
-            sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
+            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+        }
+        sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
         }
+        sumf_d8 += dm2f1.x * sumi_d1;
 
-        sumf_d += dm2f.x * sumi_d;
-        sumf_m += dm2f.y * sumi_m;
+        if (i0/QI8_1 < ns8) {
+            const float2 s8f = __half22float2(s8[i0/QI8_1]);
+            sumf -= dm2f0.y*s8f.x;
+            sumf -= dm2f1.y*s8f.y;
+        } else {
+            int sumi_m0 = 0;
+#pragma unroll
+            for (int i = i0; i < i0 + QI8_1/2; ++i) {
+                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+            }
+            sumf_d8 -= dm2f0.y * sumi_m0;
+
+            int sumi_m1 = 0;
+#pragma unroll
+            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+            }
+            sumf_d8 -= dm2f1.y * sumi_m1;
+        }
     }
 
-    return d8*(sumf_d - sumf_m);
+    return sumf + d8*sumf_d8;
 }
 
 #define VDR_Q3_K_Q8_1_MMVQ 1
@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
     return d3 * sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
     const float & d3, const float & d8) {
@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
-            sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
+            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
         }
 
         sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
     return d*sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
     const float & d6, const float * __restrict__ d8) {
 
     float sumf_d = 0.0f;
 
+    const int      sc_packed = get_int_b4(sc, 0);
+    const int8_t * sc_reg    = (const int8_t *) &sc_packed;
+
 #pragma unroll
     for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
         int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
             sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
         }
 
-        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
     }
 
     return d6 * sumf_d;