]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: use tensor cores for MMQ (llama/7676)
authorJohannes Gäßler <redacted>
Mon, 10 Jun 2024 09:45:13 +0000 (11:45 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* CUDA: int8 tensor cores for MMQ (legacy quants)

* fix out-of-bounds writes

* __builtin_assume -> GGML_CUDA_ASSUME

* fix writeback returning too early

src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-wmma-f16.cuh
src/ggml-cuda/mma.cuh [new file with mode: 0644]
src/ggml-cuda/mmq.cuh

index 90a0a81ead789679a148811acbfbf40201b4e932..7f4764d60e854a0c2b63ed28f8059820b290bccc 100644 (file)
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_TURING     750
 #define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
-#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
 static bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
+static bool int8_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
index c00f8606a5c850302e156bb462fddb5c113fee18..37b3b99323b20f276586696647ed6d27ec95ad33 100644 (file)
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         const half2 * Q_h2 = (const half2 *) Q_v;
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
     const int   q0 = x[ib].qs[iqs];
     const int   q  = ((q0 >> (4*shift)) & 0x0F);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
     const int qh  = ((qh0 >> idq) << 4) & 0x10;
     const int q   = (ql | qh) - 16;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     const int   qh  = ((qh0 >> idq) << 4) & 0x10;
     const int   q   = (ql | qh);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
     const T   d = x[ib].d;
     const int q = x[ib].qs[iqs];
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
index cb11d7212ca28ce94d3c160c432beb32477a24c3..c6c35134d4db50cb68f1ca4a376d4f99291b6137 100644 (file)
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
index 9e1aa2c6b688520c1e8d9df843116998bf00900b..02a4ad072bbd6f70d8a4b47b7dbeaf674dcb790a 100644 (file)
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
index 59cd30d7837c90bd41add5e00296e24380d36b2d..ae232224260972fb23b0b2d4f8f52d693df5faab 100644 (file)
@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
 #include <mma.h>
-#endif
+#endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
diff --git a/src/ggml-cuda/mma.cuh b/src/ggml-cuda/mma.cuh
new file mode 100644 (file)
index 0000000..71e8e34
--- /dev/null
@@ -0,0 +1,95 @@
+#include "common.cuh"
+
+struct mma_int_A_I16K8 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};
index 3ccae8a0c36fa2e6ee164a59d43a6d91e5bdb1c3..62111f376ec815c6cd80ff01dc4f8e6ee6d75ccc 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "common.cuh"
 #include "vecdotq.cuh"
+#include "mma.cuh"
 
 #include <climits>
 #include <cstdint>
@@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
 
 struct block_q8_1_mmq {
     half2  ds[4];
@@ -141,15 +143,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     = i0 + mma_A::get_i(l);
+        const int k     = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     = i0 + mma_A::get_i(l);
+        const int k     = k0 + mma_A::get_k(l) % QI4_0;
+        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+
+        A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     =    i0 + mma_A::get_i(l);
+        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+        }
+    }
+}
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
@@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A;
+    half2 dmA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i     =    i0 + mma_A::get_i(l);
+        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+
+        A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j =    j0 + mma_B::get_j(l);
+            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
+            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    mma_A A;
+    float dA[mma_C::ne/2];
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+#pragma unroll
+    for (int l = 0; l < mma_A::ne; ++l) {
+        const int i = i0 + mma_A::get_i(l);
+        const int k = k0 + mma_A::get_k(l);
+
+        A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
+    }
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+    }
+
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        mma_C C;
+        mma_B B;
+        float dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int j = j0 + mma_B::get_j(l);
+            const int k = k0 + mma_B::get_k(l);
+
+            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+        }
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
+        }
+
+        C.mma_K8(A, B);
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+        if (j >= ne1) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int i0 = threadIdx.y*mma_C::I;
+    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+
+            if (j >= ne1) {
+                continue;
+            }
+
+            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+        }
+    }
+}
+
 // -------------------------------------------------------------------------------------------------------------------------------------
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
@@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
     static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
     static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
     static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
     static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
 };
 
 static int mmq_need_sum(const ggml_type type_x) {
@@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
+    constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
 
     constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
 
@@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
 
     const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
-    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
@@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
         }
     }
 
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
-
-        if (j >= ne1) {
-            return;
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
-
-            if (need_check && i >= ne0) {
-                continue;
-            }
-
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
-        }
-    }
+    write_back(sum, dst, ne0, ne1);
 }
 
 struct mmq_args {
@@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
             launch_mul_mat_q<type,   8, 4>(args, stream);
             break;
         case  16:
-            launch_mul_mat_q<type,  16, 8>(args, stream);
+            launch_mul_mat_q<type,  16, 4>(args, stream);
             break;
         case  24:
-            launch_mul_mat_q<type,  24, 8>(args, stream);
+            launch_mul_mat_q<type,  24, 4>(args, stream);
             break;
         case  32:
             launch_mul_mat_q<type,  32, 8>(args, stream);