option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
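# Minimal configure sketch (assumes the usual GGML_HIP toggle and a local ROCm toolchain;
# adjust generator/targets for your system) showing how the rocWMMA FlashAttention option
# above would typically be enabled:
#   cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON
#   cmake --build build --config Release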
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
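// i.e. MMQ uses the CDNA matrix-core (MFMA) instructions unless the build defines
// GGML_HIP_NO_MMQ_MFMA (tied to the GGML_HIP_MMQ_MFMA CMake option above).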
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}
-// Any FP16 tensor core instructions are available for ggml code.
-static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
- return false;
-#else
- if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
- GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
- GGML_CUDA_CC_IS_MTHREADS(cc)) {
- return true;
- } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
- return true;
-#else
- return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
- } else {
- return false;
- }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
NO_DEVICE_CODE;
return;
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
#if !defined(GGML_USE_HIP)
#include <mma.h>
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA)
namespace wmma = mtmusa::wmma;
#else // GGML_USE_MUSA
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#elif defined(GGML_USE_HIP)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP)
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
}
constexpr int get_max_power_of_2(int x) {
#include "common.cuh"
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
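// Readability summary of the compile-time gate above:
//   - CUDA/MUSA: defined for non-HIP builds with __CUDA_ARCH__ >= Volta, or any MUSA build.
//   - CDNA:      requires GGML_HIP_ROCWMMA_FATTN and a rocwmma version other than v2.0.0 (broken there).
//   - RDNA3:     requires GGML_HIP_ROCWMMA_FATTN, any rocwmma version.
//   - RDNA4:     requires GGML_HIP_ROCWMMA_FATTN and rocwmma v2.0.0 or newer.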
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+ return false;
+#else
+ if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+ GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+ return true;
+    } else if (GGML_CUDA_CC_IS_CDNA(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ } else {
+ return false;
+ }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
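// Host-side usage sketch (accessor names assumed from the surrounding CUDA backend; the actual
// dispatch lives in the kernel-selection logic below):
//   const int cc = ggml_cuda_info().devices[ctx.device].cc;
//   if (ggml_cuda_should_use_wmma_fattn(cc)) {
//       ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
//   }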
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+ if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
return BEST_FATTN_KERNEL_NONE;
}
break;
}
// For large batch sizes, use the WMMA kernel if possible:
- if (fp16_mma_available(cc)) {
+ if (ggml_cuda_should_use_wmma_fattn(cc)) {
return BEST_FATTN_KERNEL_WMMA_F16;
}
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma-version.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
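// rocwmma-version.hpp supplies the ROCWMMA_VERSION_MAJOR/MINOR/PATCH macros that
// fattn-wmma-f16.cuh uses to gate the rocwmma FlashAttention path per architecture.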
+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
-if (GGML_HIP_ROCWMMA_FATTN)
- CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
- if (NOT ${FOUND_ROCWMMA})
- message(FATAL_ERROR "rocwmma has not been found")
- endif()
-endif()
if (${hip_VERSION} VERSION_LESS 6.1)
    message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
endif()

if (NOT GGML_HIP_MMQ_MFMA)
    add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
endif()
-if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
- add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
-endif()
-
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
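# -Rpass-analysis=kernel-resource-usage asks the HIP/Clang compiler to emit per-kernel resource
# remarks (registers, LDS, spills), and --save-temps keeps intermediate files such as the
# generated assembly for inspection.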