]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
musa: refine compute capability (llama/12493)
authorR0CKSTAR <redacted>
Sat, 22 Mar 2025 09:11:37 +0000 (17:11 +0800)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
* musa: refine compute capability

Signed-off-by: Xiaodong Ye <redacted>
* Address review comments

Signed-off-by: Xiaodong Ye <redacted>
---------

Signed-off-by: Xiaodong Ye <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

index e78205e5d53afd4a47ef77a49d8ad242217ef9bc..6b5cd32a4a541256bfeb91e134bb8f6fe7f5ea13 100644 (file)
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL       600
-#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA        700
-#define GGML_CUDA_CC_TURING       750
-#define GGML_CUDA_CC_AMPERE       800
-#define GGML_CUDA_CC_ADA_LOVELACE 890
-#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
-
+#define GGML_CUDA_CC_PASCAL          600
+#define GGML_CUDA_CC_DP4A            610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA           700
+#define GGML_CUDA_CC_TURING          750
+#define GGML_CUDA_CC_AMPERE          800
+#define GGML_CUDA_CC_ADA_LOVELACE    890
+#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
+#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
+#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
+
+// AMD
 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
-#define GGML_CUDA_CC_QY1        210
-#define GGML_CUDA_CC_QY2        220
+// Moore Threads
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1  (GGML_MUSA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2  (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG   (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT)
+#define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
@@ -209,21 +221,21 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
 }
 
 static bool fast_fp16_available(const int cc) {
-    return fp16_available(cc) && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -231,20 +243,20 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
 static bool cp_async_available(const int cc) {
index 973541893ec21e8df9c58d5229908ac2fcda0aee..8edc12649aa63573c3cca0068daa672fa0f0db31 100644 (file)
@@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
         if (fp16_mma_available(cc)) {
             ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
index b783310ef7ba74596a4543b5c78692603881a19c..10d461b773a6ac9feaa7e6c798a8d6008fb3f4b8 100644 (file)
@@ -264,9 +264,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
 #elif defined(GGML_USE_MUSA)
         // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
         info.devices[id].warp_size = 32;
-        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
+        info.devices[id].cc += prop.minor * 0x10;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
@@ -1188,11 +1188,11 @@ static void ggml_cuda_op_mul_mat_cublas(
     // ldc == nrows of the matrix that cuBLAS writes into
     int64_t ldc = id == ctx.device ? ne0 : row_diff;
 
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
+    if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+        if (GGML_CUDA_CC_IS_CDNA(cc)) {
             const float alpha = 1.0f;
             const float beta = 0.0f;
             CUBLAS_CHECK(
index 10f2ebb1cb43df09ea81d38548d44cb7ad10f8b9..510c1e9b2aa9e04922afd168ceafb740f43e85ec 100644 (file)
@@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q(
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
     const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
-        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+        GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     return true;
 #endif //GGML_CUDA_FORCE_MMQ
 
-    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
index f2aca1f2014e1dd4d846054bb5d05a0e9a8ffd13..4ea8b8d4b12904b1977608114e7a49f0cfaca456 100644 (file)
@@ -90,7 +90,7 @@ struct tile_x_sizes {
 
 static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
 #else
@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
 }
 
 static int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
-        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        ((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc);
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;