]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.0 (llama/16221)
authoruvos <redacted>
Wed, 1 Oct 2025 21:09:25 +0000 (23:09 +0200)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 04:57:25 +0000 (07:57 +0300)
* HIP: Disable ROCWMMA fatt on CDNA when compiled against ROCWMMA 2.0.0

rocwmma 2.0.0 includes a bug in the code fakeing fp16 accumulation on CDNA

* CUDA: Fix volta condition in ggml_cuda_should_use_wmma_fattn

CMakeLists.txt
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-tile.cu
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/fattn-wmma-f16.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/vendors/hip.h
src/ggml-hip/CMakeLists.txt

index 56420587a95930aa0152c06fc507d420f0209e75..6ce52ffc6698bba1688b0f7d87d605c4713686c1 100644 (file)
@@ -209,7 +209,6 @@ option(GGML_HIP                             "ggml: use HIP"
 option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"         OFF)
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
 option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12   "ggml: enable rocWMMA FlashAttention on GFX12"    OFF)
 option(GGML_HIP_MMQ_MFMA                    "ggml: enable MFMA MMA for CDNA in MMQ"           ON)
 option(GGML_HIP_EXPORT_METRICS              "ggml: enable kernel perf metrics output"         OFF)
 option(GGML_MUSA_GRAPHS                     "ggml: use MUSA graph, experimental, unstable"    OFF)
index c4246b65eb788aac6ba502c85a6b206c607445fb..d51abbeafa944fb966f864095c30634e29e1be82 100644 (file)
@@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-
 #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
@@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) {
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
-// Any FP16 tensor core instructions are available for ggml code.
-static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-    return false;
-#else
-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
-        GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        return true;
-    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-        return true;
-#else
-        return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-    } else {
-        return false;
-    }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
index 131a5099a3b07d1d502114abb75cef9aab5a8bfc..68de623d803499cbb602f4385105212b39e5496a 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
 
 // kq_stride == number of KQ rows to process per iteration
 // kq_nbatch == number of K columns to load in parallel for KQ calculation
@@ -190,10 +191,10 @@ static __global__ void flash_attn_tile(
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
     NO_DEVICE_CODE;
     return;
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
index 2219191fd91524ea5a638d0f4796e146adb269e2..6c90d6d52b3351b36433749614477ddd1aee3410 100644 (file)
@@ -6,19 +6,19 @@
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
 #if !defined(GGML_USE_HIP)
 #include <mma.h>
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA)
 namespace wmma = mtmusa::wmma;
 #else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#elif defined(GGML_USE_HIP)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
 }
 
 constexpr int get_max_power_of_2(int x) {
index beeea95eb1d629be62560df18a3a35e435434b65..1848d088361850436a682d46fe357eed02b7dd58 100644 (file)
@@ -1,3 +1,49 @@
 #include "common.cuh"
 
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_CDNA(cc)){
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+    } else {
+        return false;
+    }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 1cbd4f5bd6684bc5cdb272ba97bb07ae6a2a093d..d7736d36108a7b5e9a340bb7fbcd9f0a6630cd10 100644 (file)
@@ -222,7 +222,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != K->ne[0]) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+            if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -300,7 +300,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // For large batch sizes, use the WMMA kernel if possible:
-    if (fp16_mma_available(cc)) {
+    if (ggml_cuda_should_use_wmma_fattn(cc)) {
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
index 37386afcd405b73338055435a5fd68d60fc27226..890c10364983b0969f06fadb5ba4a8779fcf279b 100644 (file)
@@ -6,6 +6,10 @@
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma-version.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
index d327b90cceb25ab424918a58f52cc2252f758164..0e2b1847e09e239e37d982929b29babe94a6528b 100644 (file)
@@ -39,12 +39,6 @@ endif()
 find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
-if (GGML_HIP_ROCWMMA_FATTN)
-    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
-    if (NOT ${FOUND_ROCWMMA})
-        message(FATAL_ERROR "rocwmma has not been found")
-    endif()
-endif()
 
 if (${hip_VERSION} VERSION_LESS 6.1)
     message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
@@ -117,10 +111,6 @@ if (NOT GGML_HIP_MMQ_MFMA)
     add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
 endif()
 
-if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
-    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
-endif()
-
 if (GGML_HIP_EXPORT_METRICS)
     set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
 endif()