]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
HIP: WMMA-MMQ kernels for RDNA 4 (llama/17156)
authorJiacheng (Jason) Chen <redacted>
Mon, 24 Nov 2025 19:00:10 +0000 (14:00 -0500)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:07 +0000 (17:53 +0200)
* first commit naive test to enable mmq for RDNA4

* adding appropriate WMMA instructions

* git rebase on top of master: fixing the correctness of the mat mul operations, updating layout mappings for RDNA4

* clean up merge conflicts

* add comments and code clean up

* PR clean up, addressed comments

* enable MMQ fallback on RDNA4

* addressed comments: add guards in load generic, separate wmma branch for use_mmq function

* Revert build-xcframework.sh

* Formating: remove trailing whitespace

* revert CMake files

* clean up after rebase: remove duplicated change, revert cmake files

* clean up after rebase: revert changes from build-xcframework.sh

* clean up: remove extra space line in mma.cuh

* Revert "clean up: remove extra space line in mma.cuh"

This reverts commit b39ed57c4529906466bd0bc7c2a86e08fc2f8bee.

ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

index c3c4b7799650104fc80605c7141052c745ea2e81..caa08b360b566a71d2b900e618923cf02459ec5b 100644 (file)
@@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
 
-#if defined(GGML_USE_HIP)
-#if defined(RDNA4)
-        static constexpr int ne = I * J / 32;
-        T x[ne] = {0};
-
-        static constexpr __device__ bool supported() {
-            if (I == 16 && J == 16) return true;
-            return false;
-        }
-
-        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return 8 * (threadIdx.x / 16) + l;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-
-        static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-#else
+#if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
-#endif // defined(RDNA4)
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#endif
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
             xi[0] = xs[0];
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+        if constexpr (I == 16 && J == 4) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }else if constexpr (I == 16 && J == 8) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+
+            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+            xi[1] = xs1[0];
+        }else{
+            NO_DEVICE_CODE;
+        }
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
                                                       acc[0],
                                                       0, 0, 0);
 #endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // defined(RDNA4)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
                                                      acc[0],
                                                      0, 0, 0);
 #endif // defined(CDNA3)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
         mma(D16[1], A16[1], B);
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     }
+
+static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+    using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+    int32x2_t * a_vec = (int32x2_t *) A.x;
+    int32x2_t * b_vec = (int32x2_t *) B.x;
+
+    using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+    int32x8_t * acc = (int32x8_t *) D.x;
+
+    acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+        true,
+        a_vec[0],
+        true,
+        b_vec[0],
+        acc[0],
+        false
+    );
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif
+    }
 }
+
index a2c8760abea93cdb1e80b098b11bed1d61d53df3..03ceba874d8dbe4552f0718bb1f61e6c5d816a03 100644 (file)
@@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    if (amd_wmma_available(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+            return true;
+        }
+    }
+
+    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index 2e133b6bda8841477098b73804c7a19b7ad44975..99760d56c7294a03c224536ac1ec105ccb6f9fa9 100644 (file)
@@ -92,7 +92,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
@@ -102,7 +102,7 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 128;
 #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
@@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP)
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -231,7 +231,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    if (amd_mfma_available(cc)) {
+    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
         return mmq_x >= 128 ? 32 : 16;
     } else if (turing_mma_available(cc) && mmq_x >= 48) {
         return 16;
@@ -240,7 +240,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     }
 }
 
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 128 ? 32 : 16;
 }
@@ -265,7 +265,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
 #endif // (GGML_USE_HIP)
 
 static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 8;
 #else
     return 256/ggml_cuda_get_physical_warp_size();
@@ -279,14 +279,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
     constexpr int nrows = warp_size / threads_per_row;
@@ -305,7 +305,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
@@ -327,11 +327,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -382,14 +382,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
     constexpr int nrows = warp_size / threads_per_row;
@@ -408,12 +408,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@@ -430,11 +430,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
         x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -485,14 +485,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
     constexpr int nrows = warp_size / threads_per_row;
@@ -527,13 +527,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@@ -550,11 +550,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -563,14 +563,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
     constexpr int nrows = warp_size / threads_per_row;
@@ -603,13 +603,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@@ -626,11 +626,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
         x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -639,14 +639,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
     constexpr int threads_per_row = 32;
@@ -665,13 +665,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@@ -688,11 +688,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d;
 #else
         x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -701,14 +701,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
     constexpr int nrows = warp_size / threads_per_row;
@@ -730,13 +730,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
         const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@@ -753,11 +753,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -796,7 +796,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     typedef tile<16,  8, int> tile_A;
     typedef tile<16,  8, int> tile_B;
     typedef tile<16, 16, int> tile_C;
@@ -927,7 +927,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             }
         }
     }
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -965,7 +965,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     typedef tile<16,  8, int> tile_A;
     typedef tile<16,  8, int> tile_B;
     typedef tile<16, 16, int> tile_C;
@@ -1087,7 +1087,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             }
         }
     }
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 // Used for Q3_K, IQ2_S, and IQ2_XS
@@ -1170,6 +1170,54 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
                 tile_C C;
                 mma(C, A[n], B[0]);
 
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+    typedef tile<16,  4, int> tile_A;
+    typedef tile<16,  4, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
                     const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1257,21 +1305,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
     constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@@ -1295,11 +1343,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -1310,11 +1358,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(MMQ_TILE_NE_K + 1)   + kqsx] = x_dm_ik;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -1438,6 +1486,72 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
                 tile_C Cd;
                 mma(Cd, A[n], B[0]);
 
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+                    float tmp = Cd.x[l]*dm.x;
+                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                        tmp -= Cm.x[l]*dm.y;
+                    }
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+
+    typedef tile<16,  4, int> tile_A;
+    typedef tile<16,  4, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
+            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+            tile_C Cm;
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                tile_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                mma(Cm, A1, B);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C Cd;
+                mma(Cd, A[n], B);
+
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
                     const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1574,7 +1688,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
@@ -1582,7 +1696,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
@@ -1618,11 +1732,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
     }
 
@@ -1649,7 +1763,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         const int8_t * sc8 = (const int8_t *) &sc;
         const float d = bxi->d;
 
@@ -1659,10 +1773,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         }
 #else
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
         int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@@ -1675,7 +1789,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_df[i] = bxi->d;
     }
-#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -1728,7 +1842,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
@@ -1736,7 +1850,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -1753,19 +1867,19 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
         const int qs0 = get_int_b4(bxi->qs, txi);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         // Need if on AMD instead of % because warp_size == 64
         // This causes double work and throughput loss (MI300X)
         // H100 loses about 100 t/s with 'if' condition over '%'
@@ -1774,7 +1888,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 #else
         int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
         {
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             if (need_check) {
                 i = min(i, i_max);
             }
@@ -1829,7 +1943,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -1872,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
@@ -1908,16 +2022,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1930,7 +2044,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 #else
         int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
         {
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             if (need_check) {
                 i = min(i, i_max);
             }
@@ -1986,7 +2100,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -2029,7 +2143,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
     int   * x_sc = (int   *) (x_df + MMQ_TILE_NE_K/QI6_K);
@@ -2038,7 +2152,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2065,13 +2179,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
         const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
 #pragma unroll
@@ -2084,11 +2198,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int rows_per_warp = warp_size / 4;
@@ -2102,11 +2216,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2190,6 +2304,56 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
                 tile_C C;
                 mma(C, A[n], B[0]);
 
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+                }
+            }
+        }
+    }
+#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+    typedef tile<16,  4, int> tile_A;
+    typedef tile<16,  4, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
                     const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -2303,7 +2467,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #else
     GGML_UNUSED_VARS(x, y, sum, k00);
     NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
 }
 
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
@@ -2311,14 +2475,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2340,13 +2504,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = kbx * (2 * QI4_NL) + kqsx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0]      = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]      = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@@ -2363,11 +2527,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2376,14 +2540,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2414,22 +2578,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2438,14 +2602,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2472,24 +2636,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2498,15 +2662,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
     const int kqsx = threadIdx.x % threads_per_row;
@@ -2539,24 +2702,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2565,14 +2728,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2601,22 +2764,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = (ls*d + d/2)/2;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2625,14 +2788,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2668,22 +2831,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = ls*d;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = ls*d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2692,14 +2855,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2727,23 +2890,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta);
 #else
         x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2752,14 +2915,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2779,13 +2942,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = 8 * (kqsx / 4) + kqsx % 4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int rows_per_warp = warp_size / 8;
@@ -2804,11 +2967,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -2848,7 +3011,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int nwarps = mmq_get_nwarps_device();
 
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
     typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
     constexpr int rows_per_warp = granularity;
@@ -2859,11 +3022,11 @@ static __device__ __forceinline__ void mmq_write_back_mma(
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #else
     GGML_UNUSED(nwarps);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -3063,13 +3226,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     int * tile_y = data_mul_mat_q + mmq_x;
     int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
 #else
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;