]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
musa: enable fp16 mma (all) and cublas on qy2 (#13842)
authorR0CKSTAR <redacted>
Thu, 26 Jun 2025 04:11:59 +0000 (12:11 +0800)
committerGitHub <redacted>
Thu, 26 Jun 2025 04:11:59 +0000 (12:11 +0800)
* musa: enable fp16 mma (all) and cublas on qy2

Signed-off-by: Xiaodong Ye <redacted>
* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Johannes Gäßler <redacted>
* Address review comments

Signed-off-by: Xiaodong Ye <redacted>
* Address review comments

Signed-off-by: Xiaodong Ye <redacted>
* musa: disable MUL_MAT_ID (q2_k × f32) due to precision issues

Signed-off-by: Xiaodong Ye <redacted>
---------

Signed-off-by: Xiaodong Ye <redacted>
Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-musa/mudnn.cuh

index f6127aeee425a1eb4c3e2f3da65ae3c9111bdc42..ea203550238254b6337fe22012b9b54796d92396 100644 (file)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
-
-#define GGML_CUDA_CC_QY1  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG   (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -203,9 +201,9 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 
 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
@@ -219,9 +217,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +231,8 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -242,7 +241,8 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
+        GGML_CUDA_CC_IS_MTHREADS(cc)) {
         return true;
     } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
@@ -259,7 +259,8 @@ static bool fp16_mma_available(const int cc) {
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
 static bool bf16_mma_hardware_available(const int cc) {
index c5668adb152b2cb42ea17cba63e0f002accd6989..f3b794c3644c8a5bb7343887a15eee1f13c6f034 100644 (file)
@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
index b3e6833c396fdee87b4bb324ebf667228339a116..b30c13c62f25ce756265a726f54bf7c4f008b08e 100644 (file)
@@ -1227,9 +1227,12 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int cc = ggml_cuda_info().devices[id].cc;
 
+    const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1257,7 +1260,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (fast_fp16_hardware_available(cc) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -3061,9 +3064,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
 #ifdef GGML_USE_MUSA
-                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
-                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                    return false;
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                            a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                        return false;
+                    }
+                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                            a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                        return false;
+                    }
                 }
 #endif // GGML_USE_MUSA
                 switch (a->type) {
@@ -3090,11 +3100,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
-                            return false;
-                        }
-#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
index a63be5755c79ca00445f36f1921604995921adee..c30128561e810b169527eb939e662d8c95c4f912 100644 (file)
@@ -1,7 +1,7 @@
 #pragma once
 
-#include "../include/ggml.h"
-#include "../ggml-cuda/common.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 
 // Asynchronously copies data from src tensor to dst tensor using the provided context.
 // Returns a musaError_t indicating success or failure.