]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
musa: enable building fat binaries, enable unified memory, and disable Flash Attentio...
authorR0CKSTAR <redacted>
Sun, 22 Sep 2024 14:55:49 +0000 (22:55 +0800)
committerGitHub <redacted>
Sun, 22 Sep 2024 14:55:49 +0000 (16:55 +0200)
* mtgpu: add mp_21 support

Signed-off-by: Xiaodong Ye <redacted>
* mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas

Signed-off-by: Xiaodong Ye <redacted>
* mtgpu: enable unified memory

Signed-off-by: Xiaodong Ye <redacted>
* mtgpu: map cublasOperation_t to mublasOperation_t (sync code to latest)

Signed-off-by: Xiaodong Ye <redacted>
---------

Signed-off-by: Xiaodong Ye <redacted>
Makefile
ggml/src/CMakeLists.txt
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/vendors/musa.h

index f922f7083b7c980104773947767394cb6f88b28b..8a903d7ed5914e4fdbd65ae7741de2a40ce89e53 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -611,7 +611,7 @@ ifdef GGML_CUDA
 
                MK_CPPFLAGS  += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
                MK_LDFLAGS   += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
-               MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
+               MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
        else
                ifneq ('', '$(wildcard /opt/cuda)')
                        CUDA_PATH ?= /opt/cuda
index 527c22c6867386777c3928dcbb213645b5819e62..6c691a4c590a6ed7d1c154bd57a5f584a89d3866 100644 (file)
@@ -364,7 +364,7 @@ if (GGML_CUDA)
         if (GGML_MUSA)
             set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
             foreach(SOURCE ${GGML_SOURCES_CUDA})
-                set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+                set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
             endforeach()
         endif()
 
index 5bd4660c3fad7cbbc878699d83a52c70743c63f0..a0d2561009f5837261d29779639ccc82c7b7d8b5 100644 (file)
@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else
 
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
 
 #endif
 }
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                     return false;
                 }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
             if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
index 85eb200f03b06a3f0de01d17411d83faaf94846c..6a4bcdba095736ae98396566936f7a024b0a02c7 100644 (file)
@@ -50,6 +50,8 @@
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3      (CC_OFFSET_AMD + 1100)
+#define CC_QY1        210
+#define CC_QY2        220
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
index 827437ca0ad1ff6efdbaef3fba2d49bf7008ccc1..f402195ce0b774be63b765666f64e9b7594e2ec6 100644 (file)
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }
 
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
index 8df571149f19c9dd534648c39ff509846b2516e4..1604b8229d57fe3deb765dfbef42811cc739099c 100644 (file)
@@ -26,6 +26,7 @@
 #define cublasSetStream mublasSetStream
 #define cublasSgemm mublasSgemm
 #define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
 #define cublasGetStatusString mublasStatus_to_string
 #define cudaDataType_t musaDataType_t
 #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
@@ -56,6 +57,7 @@
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
 #define cudaMemcpy musaMemcpy
 #define cudaMemcpyAsync musaMemcpyAsync
 #define cudaMemcpyPeerAsync musaMemcpyPeerAsync