]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
MUSA: support ARM64 and enable dp4a .etc (llama/11843)
authorBodhi <redacted>
Fri, 21 Feb 2025 07:46:23 +0000 (15:46 +0800)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* MUSA:  support ARM64 and enable __dp4a .etc

* fix cross entropy loss op for musa

* update

* add cc info log for musa

* add comment for the MUSA .cc calculation block

---------

Co-authored-by: Bodhi Hu <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/cross-entropy-loss.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-impl.h
ggml/src/ggml-musa/CMakeLists.txt

index 4a92d35f9f40ca1ab191ee5a336cc632168c4423..7e99838c09261ec7069dcbe3a1deba690ab75735 100644 (file)
@@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     return __dp4a(a, b, c);
-#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     const int8_t * a8 = (const int8_t *) &a;
     const int8_t * b8 = (const int8_t *) &b;
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
index 27599a2b03839919fcdc87026babaa57678d64b3..0ce4afbb222bd22c59b6e2710adda9ae5b89d3f6 100644 (file)
@@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
 
     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
     } else {
         cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
             CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
index 093ad70991b5a208f15db32c100ac91e1babfa04..cc772801e03bdc3711ae26eea4ca29850b202e17 100644 (file)
@@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
         GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
                       id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
                       device_vmm ? "yes" : "no", prop.warpSize);
+#elif defined(GGML_USE_MUSA)
+        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
@@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-#ifdef GGML_USE_MUSA
-    GGML_ASSERT(false);
-#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif // GGML_USE_MUSA
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
index eab017889c9198adb2f3ed5b7c661431149b561c..1fbcbd0456e992d79fc5fe95841c7e2cf2bfb192 100644 (file)
@@ -16,7 +16,7 @@
 #include <arm_sve.h>
 #endif // __ARM_FEATURE_SVE
 
-#if defined(__ARM_NEON) && !defined(__CUDACC__)
+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
 //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
index 2f555416e62cf3ab7df050e03438ea5e452641b9..1bfc07c5d717a8dfe74829ee68a6359ce2de999d 100644 (file)
@@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND)
 
     set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
     foreach(SOURCE ${GGML_SOURCES_MUSA})
-        set(COMPILE_FLAGS "-x musa -mtgpu")
+        set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
         foreach(ARCH ${MUSA_ARCHITECTURES})
             set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
         endforeach()