]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HIP: Add support for RDNA4 targets (#12372)
authorSlobodan Josic <redacted>
Wed, 26 Mar 2025 22:46:30 +0000 (23:46 +0100)
committerGitHub <redacted>
Wed, 26 Mar 2025 22:46:30 +0000 (23:46 +0100)
docs/build.md
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/vendors/hip.h

index aa1db9a04713037e269caa8d0d694e68c647ca17..9c1314a29431b44c96107633e384134966646451 100644 (file)
@@ -191,7 +191,7 @@ The following compilation options are also available to tweak performance:
 
 | Option                        | Legal values           | Default | Description                                                                                                                                                                                                                                                                             |
 |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| GGML_CUDA_FORCE_MMQ           | Boolean                | false   | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower.                       |
+| GGML_CUDA_FORCE_MMQ           | Boolean                | false   | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower.                       |
 | GGML_CUDA_FORCE_CUBLAS        | Boolean                | false   | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models                                                                                                                                                                                       |
 | GGML_CUDA_F16                 | Boolean                | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs.                                                           |
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer       | 128     | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial.                                                                         |
index 954ff5f16924bf0e0652673eb8fcb99973a19528..f8c55a2b869fc9a57aa46a878506368427160407 100644 (file)
@@ -52,7 +52,7 @@
 #define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
 
 // AMD
-// GCN/CNDA, wave size is 64
+// GCN/CDNA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
 #define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
 
-// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
 #define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
 
 #define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
@@ -209,9 +211,9 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 #define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
@@ -244,14 +246,14 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
@@ -409,7 +411,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
index 6dd5dcb85e15cfc81ac1e8f3e4a42813d44f8c93..3bb472ffbfdf289c7c069cb107d22b86a1039083 100644 (file)
@@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
             const float alpha = 1.0f;
             const float beta = 0.0f;
             CUBLAS_CHECK(
@@ -1759,7 +1759,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
-    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+    int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
         beta  = &beta_f32;
@@ -1836,7 +1838,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
index 2c19485d51a92619b6db518689e12f45fe8fb9ee..b36b43d5417baf099ea2a91c319837faa180cf2d 100644 (file)
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index ee01154254e3d272cfecf8401c66114a0d322044..f136c41955b194ea5aec7f7a4e6cb65ba0619b2a 100644 (file)
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
index a7d518a574ddc0bbf632e39fc65d91614c47201b..45ea30f62df080b89d1de430816a5d02ba016c5e 100644 (file)
@@ -54,7 +54,7 @@ enum mmvq_parameter_table_id {
 };
 
 static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3)
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
     return MMVQ_PARAMETERS_RDNA2;
 #elif defined(GCN) || defined(CDNA)
     return MMVQ_PARAMETERS_GCN;
@@ -64,7 +64,7 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 }
 
 static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         return MMVQ_PARAMETERS_RDNA2;
     }
     if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
index a4c717a321cfb325b7276968e07138bb1859b92a..3983ce5b423c0804e620125dcfa8ee00948c4fec 100644 (file)
 #define CDNA
 #endif
 
+#if defined(__GFX12__)
+#define RDNA4
+#endif
+
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
     defined(__gfx1150__) || defined(__gfx1151__)
 #define RDNA3