]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: MMQ code deduplication + iquant support (llama/8495)
authorJohannes Gäßler <redacted>
Sat, 20 Jul 2024 20:25:26 +0000 (22:25 +0200)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 19:48:46 +0000 (22:48 +0300)
* CUDA: MMQ code deduplication + iquant support

* 1 less parallel job for CI build

ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu [new file with mode: 0644]
ggml/src/ggml-cuda/vecdotq.cuh

index 3af2eec2f4163277d5d1df90b75d1958c6a8686e..84f6387e2491a398522995d818cd54e48e5f390f 100644 (file)
@@ -59,6 +59,24 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
             break;
@@ -93,6 +111,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             mmq_supported = true;
index 51c44d857a4908c88fc7579bd40f50313a290f66..f08a4758d44fd932c19af5e968e5948ab81b29a2 100644 (file)
@@ -63,6 +63,14 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_IQ1_S:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             return MMQ_Q8_1_DS_LAYOUT_D4;
@@ -131,15 +139,16 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K   + mmq_y/QI4_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
+#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
+#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
+#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K,                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
@@ -152,42 +161,46 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
         type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
         type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q8_0 :
         tile_x_sizes{0, 0, 0};
 }
 
-#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0                   + 4)
-#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1                   + 4)
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE                         + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K     + WARP_SIZE/8 + 7)
-#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K     + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2                       + 4)
 #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K     + WARP_SIZE/8 + 7)
 
-static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
-static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
-    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
-        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q4_1 :
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
         type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q4_K :
-        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q8_0 :
         0;
@@ -216,7 +229,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE);
+    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -235,11 +248,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         }
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b2(bxi->qs, kqsx);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -257,7 +272,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q4_0       + kbxd] = bxi->d;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
@@ -304,95 +319,12 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE;
-    const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
-
-    mma_A A[ntx][4];
-    float dA[ntx][mma_C::ne/2][4];
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
-            const int k0 = k00 + k01;
-
-#pragma unroll
-            for (int l = 0; l < mma_A::ne; ++l) {
-                const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
-                const int k     = k0/QR4_0        + mma_A::get_k(l) % QI4_0;
-                const int shift =                4*(mma_A::get_k(l) / QI4_0);
-
-                A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
-            }
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-                dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
-            }
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
-            mma_B  B;
-            float dB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
-                }
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -411,11 +343,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         }
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b4(bxi->qs, kqsx);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -433,7 +367,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q4_1       + kbxd] = bxi->dm;
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
 #endif // INT8_MMA_AVAILABLE
@@ -480,88 +414,6 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_A_I16K4 mma_A_K4;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
-    const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
-
-    mma_A A[ntx][4];
-    half2 dmA[ntx][mma_C::ne/2][4];
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
-            A[n][k01/(QR4_1*QI4_1)].x[2]  = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
-            A[n][k01/(QR4_1*QI4_1)].x[3]  = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
-            A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
-            A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-                dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
-            }
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
-            mma_B B;
-            half2 dsB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
-                }
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
-}
-
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -789,10 +641,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -808,6 +659,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
 
     mma_A A[ntx][WARP_SIZE/QI8_0];
     float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
@@ -840,18 +692,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            mma_B B;
+            mma_B  B;
             float dB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
 
-                dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                } else {
+                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                }
             }
 
 #pragma unroll
@@ -866,10 +720,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             }
         }
     }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -905,7 +755,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
@@ -922,8 +771,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    mma_A   A[ntx][WARP_SIZE/QI8_1];
-    half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    mma_A    A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -944,7 +793,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
                 const int k0 = k00 + k01;
 
-                dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
+                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
             }
         }
     }
@@ -953,18 +802,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            const int k0 = k00 + k01;
-
-            mma_B   B;
-            half2 dsB[mma_C::ne/2];
+            mma_B    B;
+            float2 dsB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
 
-                dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
+                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
 
 #pragma unroll
@@ -974,8 +821,120 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
-                const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                }
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0],
+                    &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+                const int k0 = k00 + k01;
+
+                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
@@ -1222,7 +1181,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
-    int   * x_sc = (int   *) (x_df + 1);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -1262,23 +1220,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         }
     }
 
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
-        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
-
-#ifdef INT8_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
-#else
-        x_df[i]                       = bxi->d;
-#endif // INT8_MMA_AVAILABLE
-    }
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
         int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
@@ -1302,11 +1243,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
+        const int8_t * sc8 = (const int8_t *) &sc;
+        const float d = bxi->d;
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+        }
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = sc;
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
 #endif // INT8_MMA_AVAILABLE
     }
+
+#ifndef INT8_MMA_AVAILABLE
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        x_df[i] = bxi->d;
+    }
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1342,99 +1304,14 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int   * x_sc = (const int   *) x_df + 1;
-    const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-    mma_A   A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            const int k0 = k00 + k01;
-
-            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
-        }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
-                const int k0 = k00 + k01;
-
-                const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
-                const int8_t * sc = (const int8_t *) &sc_packed;
-
-#pragma unroll
-                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
-                    scA[n][l][k01/4 + ksc] = sc[ksc];
-                }
-            }
-
-            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
-
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
-                        (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
-                }
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+    // scale arrangement after the following two lines:
+    //   - ksc == 0: sc0, sc1, sc2, sc3
+    //   - ksc == 1: sc4, sc5, sc6, sc7
+    //   - ksc == 2:  m0,  m1,  m2,  m3
+    //   - ksc == 3:  m4,  m5,  m6,  m7
+    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+           ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
@@ -1442,8 +1319,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
-    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI4_K);
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -1451,9 +1327,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_sc = (int   *) (x_dm + txs.dm);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = 0;           // threadIdx.x / QI4_K
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + threadIdx.y;
@@ -1462,33 +1335,59 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
 #endif // INT8_MMA_AVAILABLE
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+#ifdef INT8_MMA_AVAILABLE
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
-        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
 
-#ifdef INT8_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q4_K       + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
+        int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
     }
 
 #pragma unroll
@@ -1504,17 +1403,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int * scales = (const int *) bxi->scales;
 
         const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-#ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
-#else
-        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
-#endif // INT8_MMA_AVAILABLE
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1544,124 +1437,10 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 
                 sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
                     &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
-                    x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
-            }
-        }
-    }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
-    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI4_K;
-    const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-    mma_A   A[ntx][4];
-    int   scA[ntx][mma_C::ne/2][4];
-    int    mA[ntx][mma_C::ne/2][4];
-    half2 dmA[ntx][mma_C::ne/2];
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_A::ne; ++l) {
-                A[n][k01/8 + 1].x[l]  = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
-                A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
-            }
-        }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
-            const int  m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
-
-            const uint8_t * sc = (const uint8_t *) &sc_packed;
-            const uint8_t *  m = (const uint8_t *)  &m_packed;
-
-#pragma unroll
-            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
-                scA[n][l][ksc] = sc[ksc];
-                mA[n][l][ksc]  =  m[ksc];
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
-
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
-        }
     }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmpd[ntx][mma_C::ne] = {{0.0f}};
-        float tmpm[ntx][mma_C::ne] = {{0.0f}};
-
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B   B;
-            half2 dsB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
-            }
-
-#pragma unroll
-            for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/8], B);
-
-#pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
-                    tmpm[n][l] += mA[n][l/2][k01/8]           * __high2float(dsB[l%2]);
-                }
-            }
-        }
-
-#pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
-            }
-        }
-    }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
@@ -1670,7 +1449,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 #ifdef INT8_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
-    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI5_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
@@ -1678,9 +1456,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_sc = (int   *) (x_dm + txs.dm);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = 0;           // threadIdx.x / QI5_K
-    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + threadIdx.y;
@@ -1689,73 +1464,91 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
-        const int ky = QR5_K*kqsx;
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const int ky = QR5_K*threadIdx.x;
 
-        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int ql = get_int_b4(bxi->qs, threadIdx.x);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
-        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
-        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+        const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
 
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
-        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
+        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
 #endif // INT8_MMA_AVAILABLE
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+#ifdef INT8_MMA_AVAILABLE
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
-        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
 
-#ifdef INT8_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q5_K       + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
+        int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
 
         const int * scales = (const int *) bxi->scales;
 
         const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-#ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
-#else
-        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
-#endif // INT8_MMA_AVAILABLE
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
+#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1771,134 +1564,24 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 
 // #pragma unroll
     for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
-        const int k0 = k00 + k01;
-
-#pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-#pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
-
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
-                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
-                    x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
-            }
-        }
-    }
-}
-
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
-
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
-
-    constexpr int granularity = mmq_get_granularity_device(mmq_x);
-    constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
-
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
-
-    const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
-    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI5_K;
-    const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
-
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
-
-    mma_A   A[ntx][4];
-    int   scA[ntx][mma_C::ne/2][4];
-    int    mA[ntx][mma_C::ne/2][4];
-    half2 dmA[ntx][mma_C::ne/2];
-
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            const int k0 = k00 + k01;
-
-            A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
-        }
-
-#pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-            const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
-            const int  m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
-
-            const uint8_t * sc = (const uint8_t *) &sc_packed;
-            const uint8_t *  m = (const uint8_t *)  &m_packed;
-
-#pragma unroll
-            for (int ksc = 0; ksc < sizeof(int); ++ksc) {
-                scA[n][l][ksc] = sc[ksc];
-                mA[n][l][ksc]  =  m[ksc];
-            }
-        }
-
-    #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
-
-            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
-        }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmpd[ntx][mma_C::ne] = {{0.0f}};
-        float tmpm[ntx][mma_C::ne] = {{0.0f}};
-
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            const int k0 = k00 + k01;
-
-            mma_B   B;
-            half2 dsB[mma_C::ne/2];
-
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
-
-#pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
-
-                dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
-            }
+        const int k0 = k00 + k01;
 
 #pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/8], B);
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) *  __low2float(dsB[l%2]);
-                    tmpm[n][l] += mA[n][l/2][k01/8]           * __high2float(dsB[l%2]);
-                }
-            }
-        }
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-#pragma unroll
-        for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
-#else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
-    NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
@@ -1915,9 +1598,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_sc = (int   *) (x_df + txs.dm);
 #endif // INT8_MMA_AVAILABLE
 
-    const int kbx  = 0;           // threadIdx.x / QI6_K
-    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + threadIdx.y;
@@ -1926,19 +1606,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
-        const int ky = QR6_K*kqsx;
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-        const int ql = get_int_b2(bxi->ql, kqsx);
+        const int ql = get_int_b2(bxi->ql, threadIdx.x);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
-        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
-        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
+        const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 =  (qh >> ((threadIdx.x & 0x08) >> 2))       & 0x30303030;
 
-        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
-        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
+        const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
+        const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
 
 #ifdef INT8_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
@@ -2187,6 +1866,358 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+        const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+        const uint8_t * aux8 = (const uint8_t *) &q2;
+        const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+        for (int l = 0; l < QR2_XXS; ++l) {
+            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+        const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+    #pragma unroll
+        for (int l = 0; l < QR2_XS; ++l) {
+            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR2_S; ++l) {
+            const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+        const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * q3 = (const uint8_t *) &q3_packed;
+        const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+        for (int l = 0; l < QR3_XXS; ++l) {
+            const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+        const int2      qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR3_S; ++l) {
+            const int2 grid_pos = make_int2(
+                iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+                iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI1_S;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+    #pragma unroll
+        for (int l = 0; l < QR1_S/2; ++l) {
+            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+            const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+            const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+        const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+        x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -2320,7 +2351,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
     static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2328,7 +2359,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
     static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2336,7 +2367,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
     static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2352,7 +2383,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
     static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2368,7 +2399,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
     static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2376,7 +2407,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2384,7 +2415,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2396,11 +2427,59 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+    static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
     static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2408,7 +2487,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
@@ -2837,6 +2916,12 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 
index ffeb3c27df2f43bc390385a41e90bf71cbca2c48..d7874e6eaf83280d4e48438b4fa9f639723749f8 100755 (executable)
@@ -23,7 +23,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
-    "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
+    "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
+    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]
 
 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
new file mode 100644 (file)
index 0000000..84ec850
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
new file mode 100644 (file)
index 0000000..583c4e5
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
new file mode 100644 (file)
index 0000000..edaf156
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
new file mode 100644 (file)
index 0000000..233d934
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
new file mode 100644 (file)
index 0000000..6092dc7
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
new file mode 100644 (file)
index 0000000..1d5bd20
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
index 6a17d0f3eb0d29155f8f524968c13b4858c82e74..40091a0ef07b43230c6d7391f8b4a1045d1175b4 100644 (file)
@@ -188,6 +188,27 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 }
 
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
+    const int * v, const int * u, const float * d8_0, const float & d8_1) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
+        int sumi = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_0/2; ++i) {
+            // SIMD dot product of quantized values
+            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+        }
+
+        sumf += d8_0[i0/(QI8_0/2)]*sumi;
+    }
+
+    return d8_1*sumf;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4