]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: add CDNA3 MFMA support for flash attention MMA kernel (llama/19806)
authorJayant Lohia <redacted>
Fri, 27 Feb 2026 18:37:26 +0000 (00:07 +0530)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* CUDA: add CDNA3 MFMA support for flash attention MMA kernel

Add MI300X (gfx942) MFMA tensor core flash attention using
v_mfma_f32_16x16x16_f16 (FP16 in, FP32 accumulate).

- Add FATTN_WARP_SIZE=64 for CDNA wavefront64
- Add CDNA config for head sizes 64, 80, 96, 112, 128
- Add FP16 MFMA intrinsic path in mma.cuh
- Add manual V transpose load for MFMA register layout
- Route CDNA to MMA for prompt processing, VEC for token generation
- Fix Q loading and combine stride granularity for non-power-of-2 heads

Benchmarks (Qwen2.5-1.5B Q4_K_M, MI300X):
  pp512  +7%,  pp1024 +13%,  pp2048 +23%,  pp4096 +39%
  tg128  -10% (FA overhead, VEC used for both)

All 2480 flash attention tests pass.

Ref: https://github.com/ggml-org/llama.cpp/issues/17917

* address review: replace FATTN_WARP_SIZE with constexpr, improve dispatch

- Replace #define FATTN_WARP_SIZE with constexpr int warp_size =
  ggml_cuda_get_physical_warp_size() in each device function
- Use ne[1]*gqa_ratio threshold for MMA vs tile dispatch. Benchmarked
  crossover on MI300X @ d32768 with power-of-2 GQA models:
    hsk=64  (Llama 1B, gqa=4): MMA wins at eff >= 128 (+11%)
    hsk=128 (Llama 3B, gqa=4): MMA wins at eff >= 128 (+4%)
  Unified threshold: eff_nq >= 128 for all head sizes.
- Remove VEC fallback; small batches fall through to tile kernel

* Update ggml/src/ggml-cuda/fattn.cu

* use ggml_cuda_info().devices warp_size instead of hardcoded check

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/mma.cuh

index 0b8ef90794c7cc63db0b8300e7bab17d16814602..beb7e32e4fc35a08c4d10244664821508ac19e9b 100644 (file)
@@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
+    // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  32,  32,  32, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 256, 2,  64,  40,  40,  40, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 256, 2,  64,  48,  48,  48, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2,  64,  56,  56,  56, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2,  64,  64,  64,  64, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2,  32, 128, 128, 128, 1, true);
+
+    // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
+    // compile-time static_asserts even though the kernel guard prevents runtime execution.
+    // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
+    return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
+}
+
 static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
     if (turing_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    if (amd_mfma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
+    }
     if (amd_wmma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
     }
@@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 #elif defined(TURING_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(AMD_MFMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
 #elif defined(VOLTA_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 #elif defined(AMD_WMMA_AVAILABLE)
@@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
 }
 
 static constexpr __device__ int get_cols_per_thread() {
-#if defined(AMD_WMMA_AVAILABLE)
-    return 1; // RDNA has a single column.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    return 1; // AMD has a single column per thread.
 #else
     return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 }
 
 static __host__ int get_cols_per_warp(const int cc) {
-    if (turing_mma_available(cc) || amd_wmma_available(cc)) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
         return 16;
     } else {
         // Volta
@@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
 template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
     if constexpr (use_cp_async) {
@@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
         auto load = [&] __device__ (auto n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
             const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                 }
@@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     } else {
         // TODO use ggml_cuda_memcpy_1
         auto load = [&] __device__ (const int n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
             const int k0_stop  =                             D2 - D2 % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                 }
@@ -324,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
         const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     if constexpr (use_cp_async) {
-        static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
+        static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
         static_assert(!oob_check, "OOB check incompatible with cp_async");
         constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
-        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+        constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
@@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
             }
         }
-    } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
-        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+    } else if constexpr (nbatch_fa < 2*warp_size) {
+        constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
-            const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
+            const int i = threadIdx.x % (warp_size/cols_per_warp);
 
             ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
         }
@@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
                 const int i = i0 + 2*threadIdx.x;
 
                 ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
@@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
+    constexpr int  warp_size       = ggml_cuda_get_physical_warp_size();
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
     constexpr int  cols_per_thread = get_cols_per_thread();
@@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #else // Volta
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
@@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
                         // Wide version of KQ_C is column-major
-#if defined(AMD_WMMA_AVAILABLE)
-                        // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
 #else
                         // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
@@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
                     } else {
                         // Wide version of KQ_C is column-major
-#if defined(AMD_WMMA_AVAILABLE)
-                        // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
 #else
                         // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
@@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     constexpr int KQ_idx = 0;
 #else
                     // Turing + Volta:
                     const int KQ_idx = l % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
@@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = 16; offset >= 4; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     constexpr int KQ_idx = 0;
 #else
                     // Turing + Volta:
                     const int KQ_idx = l % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
                     KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                 } else {
@@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     constexpr int KQ_idx = 0;
 #else
                     // Turing + Volta:
                     const int KQ_idx = (l/2) % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
@@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             // Values per KQ column are spread across 4 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+            // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
+            constexpr int offset_first = 32;
+            constexpr int offset_last  = 16;
 #elif defined(AMD_WMMA_AVAILABLE)
             // Values per KQ column are spread across 2 threads:
             constexpr int offset_first = 16;
@@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     constexpr int KQ_idx = 0;
 #else
                     // Turing + Volta:
                     const int KQ_idx = (l/2) % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
                     KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                 } else {
@@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         }
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         const half2 KQ_max_scale_h2 = make_half2(
             KQ_max_scale[0], KQ_max_scale[0]);
 #pragma unroll
@@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
         const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
 
-#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
         for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
 #if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#elif defined(AMD_MFMA_AVAILABLE)
+                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
+                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
+                // Load with transposed addressing: 4 strided half loads.
+                {
+                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
+                    const half * xs0_h = (const half *) xs0;
+                    const int stride_h = stride_tile_V * 2; // stride in half units
+                    half * A_h = (half *) A.x;
+#pragma unroll
+                    for (int l = 0; l < 4; ++l) {
+                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
+                    }
+                }
 #else
                 // TODO: Try to transpose tile_V when loading gmem to smem.
                 // Use mma to transpose T_A_VKQ for RDNA.
                 T_A_VKQ A_trans;
                 load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
                 mma(A, A_trans, A_identity);
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
                 if constexpr (T_B_KQ::I == 8) {
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
                     // Wide version of VKQ_C is column-major.
-#if defined(AMD_WMMA_AVAILABLE)
-                    // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    // AMD matrix C is column-major.
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
 #else
                     // swap A and B for CUDA.
                     mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                 }
             }
         }
@@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
             }
         }
-#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 
         if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
@@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> {
     using T_B_VKQ = tile< 8,  8, half2>; // column-major
     using T_C_VKQ = tile<16,  4, half2>; // row-major
 };
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 template<int ncols> struct mma_tile_sizes {
     using T_A_KQ  = tile<16,  8, half2>; // row-major
     using T_B_KQ  = tile<16,  8, half2>; // column-major
@@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols = ncols1 * ncols2;
     using     T_A_KQ    = typename mma_tile_sizes<ncols>::T_A_KQ;
     using     T_B_KQ    = typename mma_tile_sizes<ncols>::T_B_KQ;
@@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
 #if defined(TURING_MMA_AVAILABLE)
     T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #else // Volta
     T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
@@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
-    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start  = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+        const int k0_start  = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
         const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);
-        const int stride_jc = WARP_SIZE / stride_k;
+        const int stride_jc = warp_size / stride_k;
 
         if (k0_start == k0_stop) {
             continue;
@@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
         for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
-            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
             if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
                 break;
@@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
                     tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                 }
@@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // The partial sums are spread across 8/4 threads.
         constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
         constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+        // The partial sums are spread across 4 threads (wavefront64, 16 cols).
+        constexpr int offset_first = 32;
+        constexpr int offset_last  = 16;
 #elif defined(AMD_WMMA_AVAILABLE)
         // The partial sums are spread across 2 threads.
         constexpr int offset_first = 16;
@@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
             }
         }
     }
@@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 }
             }
         }
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
 #pragma unroll
         for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
@@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
         const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
         const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
         const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
         const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
@@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.
 
-        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+        constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
 
-        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
         float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
         float2 meta[nmeta];
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+            meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
         }
 
         float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            if (offset < warp_size) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
             }
         }
 
@@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            if (offset < warp_size) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
             }
         }
 
@@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Write back combined meta data:
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+            if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
                 // Combined KQ max scale + rowsum.
-                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+                meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
             }
         }
 
         // Combined KQ max + rowsum.
-        static_assert(cols_per_warp <= WARP_SIZE);
-        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        static_assert(cols_per_warp <= warp_size);
+        if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
@@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
 
 #pragma unroll
-            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-                const int k0_start  = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+            for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+                const int k0_start  = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
                 const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);
-                const int stride_jc = WARP_SIZE / stride_k;
+                const int stride_jc = warp_size / stride_k;
 
                 if (k0_start == k0_stop) {
                     continue;
@@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
                 for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
-                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                     if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
                         break;
@@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
 #pragma unroll
                     for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                        const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                         float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
@@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
@@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // defined(AMD_WMMA_AVAILABLE)
 
+#if defined(AMD_MFMA_AVAILABLE)
+    if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_MFMA_AVAILABLE)
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols     = ncols1 * ncols2;
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
-    constexpr int nwarps    = nthreads / WARP_SIZE;
+    constexpr int nwarps    = nthreads / warp_size;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
@@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
     const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
-    const int nwarps        = nthreads / WARP_SIZE;
+    const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
+    const int nwarps         = nthreads / warp_size_host;
 
     constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
 
@@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     }
 
     launch_fattn<DV, ncols1, ncols2>
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
 }
 
 
index 721edd9994407d62b22e4cba939fbb3b6a1916ab..85c177f496fad2bb6ecbee64f06fbbc8ea759027 100644 (file)
@@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_MMA_F16;
     }
 
+    // Use MFMA flash attention for CDNA (MI100+):
+    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
+        const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
+        // MMA vs tile crossover benchmarked on MI300X @ d32768:
+        //   hsk=64  (gqa=4): MMA wins at eff >= 128 (+11%)
+        //   hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
+        if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
+            return BEST_FATTN_KERNEL_MMA_F16;
+        }
+        // Fall through to tile kernel for small effective batch sizes.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
index dd45d6c78fd743213cce4933db2c723182614578..5d1dadd3e4f41f1d8bc8be9268848f5a92a54e40 100644 (file)
@@ -668,7 +668,7 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -964,6 +964,34 @@ namespace ggml_cuda_mma {
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        // MFMA: FP16 input, FP32 accumulate, convert back to half2.
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+
+        // Convert existing half2 accumulator to float for MFMA:
+        floatx4_t acc_f32;
+        {
+            const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                acc_f32[i] = (float)acc_h[i];
+            }
+        }
+
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
+
+        // Convert back to half2:
+        {
+            halfx4_t result_h;
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                result_h[i] = (_Float16)acc_f32[i];
+            }
+            reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
+        }
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;