]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HIP: enable mfma mmq on gfx908 and gfx90a for select datatypes and shapes (#14949)
authoruvos <redacted>
Wed, 30 Jul 2025 15:38:06 +0000 (17:38 +0200)
committerGitHub <redacted>
Wed, 30 Jul 2025 15:38:06 +0000 (17:38 +0200)
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

index 4560d85c4cfd3d4b0dcddfc96b637d7cf1508cfe..7fb04d51b770f9dcaeff3b6fb97d07a347152e4a 100644 (file)
@@ -227,9 +227,9 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
-#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
@@ -293,10 +293,9 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }
 
-// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
 static bool amd_mfma_available(const int cc) {
 #if !defined(GGML_HIP_NO_MMQ_MFMA)
-    return GGML_CUDA_CC_IS_CDNA3(cc);
+    return GGML_CUDA_CC_IS_CDNA(cc);
 #else
     return false;
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
index e2fd0c1c254e19fe18bf8e86fd71966dc0209f67..d4954fbe69e11b7db243ee286e71f9b245e223c7 100644 (file)
@@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3  =  dst->nb[3] / ts_dst;
 
-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-                            || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
+    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+                            || GGML_CUDA_CC_IS_CDNA(cc);
 
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q(
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
     const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-                            || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
+                            || GGML_CUDA_CC_IS_CDNA(cc))
                             && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
@@ -306,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc) || amd_mfma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
 
@@ -322,5 +322,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
+    if (amd_mfma_available(cc)) {
+        // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
+        // performs better but is currently suffering from a crash on this architecture.
+        // TODO: Revisit when hipblaslt is fixed on CDNA3
+        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+            return true;
+        }
+        if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+            return true;
+        }
+        if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
+            return true;
+        }
+        return false;
+    }
+
     return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index d8650dd70b91515c8120b99ccd37fb617d956750..dd60529faa41d58dab99800ae0b19b6fc85fa047 100644 (file)
@@ -3096,8 +3096,8 @@ static __global__ void mul_mat_q(
     }
     __syncthreads();
 
-    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;