]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (llama/12032)
authorDavid Huang <redacted>
Mon, 3 Mar 2025 21:10:54 +0000 (05:10 +0800)
committerGeorgi Gerganov <redacted>
Sat, 8 Mar 2025 13:13:01 +0000 (15:13 +0200)
Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16

ggml/CMakeLists.txt
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-hip/CMakeLists.txt

index 835bf16e55ac34983a911b226cca2b38a3b2ff7e..92a62ea6fb5d5e1c39dfcac0179ca5c6b2be5ac0 100644 (file)
@@ -162,6 +162,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
 option(GGML_HIP                             "ggml: use HIP"                                   OFF)
 option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"         OFF)
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
+option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
 option(GGML_HIP_UMA                         "ggml: use HIP unified memory architecture"       OFF)
 option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
index adf0d3ecb566cc78bfd29726b1cc929ed0a66bf9..1832314ec133be218ff2c048f34b949762a99de4 100644 (file)
@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
 
+#define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
index 7b9566fb4be3268b89e4bde50c10c42c7c491841..46de14093545c64f0ec50b6e951e93565dea4e31 100644 (file)
@@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
@@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
@@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib  = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }
 
-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);
 
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }
 
         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const half2 K_ik = K_h2[k_KQ];
-        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }
 
     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
index 8828652fb5e7f1c7437650c27adafacae332e917..622cf28576d2997361b05a7bd6fd475b25f6fadb 100644 (file)
@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16(
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
     const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
 
@@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
 
     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
@@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
@@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }
 
@@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
 
                     if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                     }
                 }
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                 }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
+                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
 
                 const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_f[j0/nwarps] = expf(diff);
@@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/warp_size] = expf(diff);
                     if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                        KQ_f_tmp[k0/warp_size] = 0.0f;
                     }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
 
                     if (use_logit_softcap) {
                         // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
 
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                     }
                 }
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
@@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
@@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }
 
 #pragma unroll
@@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
 
@@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                if (i0 + warp_size > D/2 && i >= D/2) {
                     break;
                 }
 
@@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
@@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
@@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q   = dst->src[0];
 
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
 
     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
@@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
index b1becccb4de72f432fbd7690be0e3c6171562298..24f973056aa9a78a3cf1347e64adba2bb678ec00 100644 (file)
@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+        if (fp16_mma_available(cc)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            return;
+        }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+        // On AMD the tile kernels perform poorly, use the vec kernel instead:
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
@@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int gqa_ratio = Q->ne[2] / K->ne[2];
     const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
         K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
index 4a0384dd47654ad3e507640ca423fedd7218898c..e3762649fd27574001ec1e5bb3f51cbdc739c2c8 100644 (file)
@@ -39,6 +39,12 @@ endif()
 find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
+if (GGML_HIP_ROCWMMA_FATTN)
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
+    if (NOT ${FOUND_ROCWMMA})
+        message(FATAL_ERROR "rocwmma has not been found")
+    endif()
+endif()
 
 if (${hip_VERSION} VERSION_LESS 5.5)
     message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (GGML_HIP_ROCWMMA_FATTN)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()