]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
feat: Support Moore Threads GPU (llama/8383)
authorR0CKSTAR <redacted>
Sat, 27 Jul 2024 23:41:25 +0000 (07:41 +0800)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 19:48:46 +0000 (22:48 +0300)
* Update doc for MUSA

Signed-off-by: Xiaodong Ye <redacted>
* Add GGML_MUSA in Makefile

Signed-off-by: Xiaodong Ye <redacted>
* Add GGML_MUSA in CMake

Signed-off-by: Xiaodong Ye <redacted>
* CUDA => MUSA

Signed-off-by: Xiaodong Ye <redacted>
* MUSA adds support for __vsubss4

Signed-off-by: Xiaodong Ye <redacted>
* Fix CI build failure

Signed-off-by: Xiaodong Ye <redacted>
---------

Signed-off-by: Xiaodong Ye <redacted>
ggml/CMakeLists.txt
ggml/include/ggml-cuda.h
ggml/src/CMakeLists.txt
ggml/src/ggml-common.h
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/common.cuh

index 1768a508bb9f68e053afde0d2183182f59643fa0..a5c2e96a86ca05c5410d394474901f218a7f39e7 100644 (file)
@@ -113,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
 option(GGML_LLAMAFILE                       "ggml: use LLAMAFILE"                             OFF)
 
 option(GGML_CUDA                            "ggml: use CUDA"                                  OFF)
+option(GGML_MUSA                            "ggml: use MUSA"                                  OFF)
 option(GGML_CUDA_FORCE_DMMV                 "ggml: use dmmv instead of mmvq CUDA kernels"     OFF)
 option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
 option(GGML_CUDA_FORCE_CUBLAS               "ggml: always use cuBLAS instead of mmq kernels"  OFF)
index d7903c666cebfc6d89a43b3d73031833b360cbf5..71bb6dcf07975e95949fed4b41c48cf024faa212 100644 (file)
@@ -6,6 +6,9 @@
 #ifdef GGML_USE_HIPBLAS
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
 #else
 #define GGML_CUDA_NAME "CUDA"
 #define GGML_CUBLAS_NAME "cuBLAS"
index c6496c9211d70ad69b14f7ffd4520126a4de65b4..836496fb95de5bc7200f16f392a11ff9249942c0 100644 (file)
@@ -139,6 +139,17 @@ if (GGML_METAL)
         )
 endif()
 
+if (GGML_MUSA)
+    set(CMAKE_C_COMPILER clang)
+    set(CMAKE_C_EXTENSIONS OFF)
+    set(CMAKE_CXX_COMPILER clang++)
+    set(CMAKE_CXX_EXTENSIONS OFF)
+
+    set(GGML_CUDA ON)
+
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
+endif()
+
 if (GGML_OPENMP)
     find_package(OpenMP)
     if (OpenMP_FOUND)
@@ -147,6 +158,11 @@ if (GGML_OPENMP)
         add_compile_definitions(GGML_USE_OPENMP)
 
         set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
+
+        if (GGML_MUSA)
+            set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp")
+            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so")
+        endif()
     else()
         message(WARNING "OpenMP not found")
     endif()
@@ -249,7 +265,13 @@ endif()
 if (GGML_CUDA)
     cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
 
-    find_package(CUDAToolkit)
+    if (GGML_MUSA)
+        list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
+        find_package(MUSAToolkit)
+        set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
+    else()
+        find_package(CUDAToolkit)
+    endif()
 
     if (CUDAToolkit_FOUND)
         message(STATUS "CUDA found")
@@ -268,7 +290,11 @@ if (GGML_CUDA)
         endif()
         message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 
-        enable_language(CUDA)
+        if (GGML_MUSA)
+            set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
+        else()
+            enable_language(CUDA)
+        endif()
 
         file(GLOB   GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
         list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
@@ -332,21 +358,40 @@ if (GGML_CUDA)
             add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
         endif()
 
+        if (GGML_MUSA)
+            set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
+            foreach(SOURCE ${GGML_SOURCES_CUDA})
+                set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+            endforeach()
+        endif()
+
         if (GGML_STATIC)
             if (WIN32)
                 # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
                 set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
             else ()
-                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+                if (GGML_MUSA)
+                    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
+                else()
+                    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+                endif()
             endif()
         else()
-            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+            if (GGML_MUSA)
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
+            else()
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+            endif()
         endif()
 
         if (GGML_CUDA_NO_VMM)
             # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
         else()
-            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+            if (GGML_MUSA)
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
+            else()
+                set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+            endif()
         endif()
     else()
         message(WARNING "CUDA not found")
@@ -857,8 +902,10 @@ function(get_flags CCID CCVER)
         set(C_FLAGS   -Wdouble-promotion)
         set(CXX_FLAGS -Wno-array-bounds)
 
-        if (CCVER VERSION_GREATER_EQUAL 7.1.0)
-            list(APPEND CXX_FLAGS -Wno-format-truncation)
+        if (NOT GGML_MUSA)
+            if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+                list(APPEND CXX_FLAGS -Wno-format-truncation)
+            endif()
         endif()
         if (CCVER VERSION_GREATER_EQUAL 8.1.0)
             list(APPEND CXX_FLAGS -Wextra-semi)
@@ -1264,6 +1311,7 @@ endif()
 target_compile_definitions(ggml PUBLIC  ${GGML_CDEF_PUBLIC})
 target_include_directories(ggml PUBLIC ../include)
 target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
+target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})
 target_compile_features   (ggml PRIVATE c_std_11) # don't bump
 
 target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
index fafd5fa7ae000ddd2d5646ce69947332aa4c3385..e40057632fc5aadee29f84c5b7857b4c4d728c30 100644 (file)
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>
 
 typedef half  ggml_half;
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
 #define GGML_TABLE_END() };
 
 #define GGML_COMMON_IMPL
-#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
 #include <cstdint>
 
 #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
index 54ccf6bb1703c0f6a19797d0cc88d02d141f07bb..c73ae40d49da61d76939ae70f1a8e230c962c996 100644 (file)
@@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
@@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
 
 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
-#endif
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
@@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
 static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
     void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
 
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
     // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
     cudaMemcpy3DPeerParms p = {};
     p.dstDevice = dstDevice;
@@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
     GGML_UNUSED(dstDevice);
     GGML_UNUSED(srcDevice);
     return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 }
 
 static void ggml_cuda_op_mul_mat(
@@ -1828,6 +1828,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
+#ifdef GGML_USE_MUSA
+    GGML_ASSERT(false);
+#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1870,6 +1873,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif // GGML_USE_MUSA
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -3027,7 +3031,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         return false;
     }
 
-#if CUDART_VERSION >= 11100
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
     cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
     if (err != cudaSuccess) {
         // clear the error
index eac026f478e5a35de683c1cf713c6a611cc1dce1..8c3c20b90ad668f7440e0da6b4ca040f40e8129c 100644 (file)
 #else
 #define GGML_COMMON_DECL_CUDA
 #define GGML_COMMON_IMPL_CUDA
+#if defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
+#endif
 #endif
 #include "ggml-common.h"
 
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#elif defined(GGML_USE_MUSA)
+#include <musa_runtime.h>
+#include <musa.h>
+#include <mublas.h>
+#include <musa_fp16.h>
+// XXX: Keep the following order the same as hipBLAS
+// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
+// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+// #define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F  MUSA_R_16F
+#define CUDA_R_32F  MUSA_R_32F
+// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+// #define cublasComputeType_t mublasComputeType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
+#define cudaStream_t musaStream_t
+#define cudaSuccess musaSuccess
+
+// XXX: Other CUDA => MUSA mapping
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// XXX: USE_CUDA_GRAPH
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamEndCapture musaStreamEndCapture
+
+// XXX: cuBLAS => muBLAS mapping
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+
+// XXX: Clang builtins mapping
+#define __vsub4   __vsub4_musa
+#define __vcmpeq4 __vcmpeq4_musa
+#define __vcmpne4 __vcmpne4_musa
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -168,9 +316,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
 
-#if CUDART_VERSION >= 12000
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
     static const char * cublas_get_error_str(const cublasStatus_t err) {
+#ifndef GGML_USE_MUSA
         return cublasGetStatusString(err);
+#else
+        return mublasStatus_to_string(err);
+#endif // GGML_USE_MUSA
     }
 #else
     static const char * cublas_get_error_str(const cublasStatus_t err) {
@@ -200,7 +352,7 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if CUDART_VERSION >= 11100
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
 #define GGML_CUDA_ASSUME(x)
@@ -214,6 +366,42 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
+#if defined(GGML_USE_MUSA)
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+
+static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
+static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+    }
+    return c;
+}
+#endif // defined(GGML_USE_MUSA)
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -455,7 +643,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
     const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
     return mask_low | mask_high;
 }
-#endif // CUDART_VERSION < 12000
+#endif // CUDART_VERSION < CUDART_HMASK
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)