]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENTITY op to accele...
authorR0CKSTAR <redacted>
Wed, 21 May 2025 01:58:49 +0000 (09:58 +0800)
committerGitHub <redacted>
Wed, 21 May 2025 01:58:49 +0000 (09:58 +0800)
* musa: fix build warning (unused parameter)

Signed-off-by: Xiaodong Ye <redacted>
* musa: upgrade MUSA SDK version to rc4.0.1

Signed-off-by: Xiaodong Ye <redacted>
* musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy

Signed-off-by: Xiaodong Ye <redacted>
* Update ggml/src/ggml-cuda/cpy.cu

Co-authored-by: Johannes Gäßler <redacted>
* musa: remove MUDNN_CHECK_GEN and use CUDA_CHECK_GEN instead in MUDNN_CHECK

Signed-off-by: Xiaodong Ye <redacted>
---------

Signed-off-by: Xiaodong Ye <redacted>
Co-authored-by: Johannes Gäßler <redacted>
.devops/musa.Dockerfile
.github/workflows/build.yml
README.md
ci/README.md
docs/docker.md
ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-musa/CMakeLists.txt
ggml/src/ggml-musa/mudnn.cu [new file with mode: 0644]
ggml/src/ggml-musa/mudnn.cuh [new file with mode: 0644]

index e0f1ad9728b0991b4df27cf9f7ac64cc15261aa6..87ce2393f6bf9b1ef7acbfedb05c3eac32cb405d 100644 (file)
@@ -1,10 +1,10 @@
 ARG UBUNTU_VERSION=22.04
 # This needs to generally match the container host's environment.
-ARG MUSA_VERSION=rc3.1.1
+ARG MUSA_VERSION=rc4.0.1
 # Target the MUSA build image
-ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
+ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
 
-ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
+ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
 
 FROM ${BASE_MUSA_DEV_CONTAINER} AS build
 
@@ -21,21 +21,14 @@ RUN apt-get update && \
     libcurl4-openssl-dev \
     libgomp1
 
-COPY requirements.txt   requirements.txt
-COPY requirements       requirements
-
-RUN pip install --upgrade pip setuptools wheel \
-    && pip install -r requirements.txt
-
 WORKDIR /app
 
 COPY . .
 
-# Use the default MUSA archs if not specified
 RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
         export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
     fi && \
-    cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
+    cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
     cmake --build build --config Release -j$(nproc)
 
 RUN mkdir -p /app/lib && \
index b60629b2415110d5c15fd88a5cdf7a2bf2081f4e..ee76d1799e6f474a19f65cdbe2e218c290d50b79 100644 (file)
@@ -351,7 +351,7 @@ jobs:
 
   ubuntu-22-cmake-musa:
     runs-on: ubuntu-22.04
-    container: mthreads/musa:rc3.1.1-devel-ubuntu22.04
+    container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
 
     steps:
       - name: Clone
index 5472f7abdeb21d997c8db82c746281367586a624..d1cb8d8336229233c6cdbcdf86e0815e1f6b1032 100644 (file)
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ range of hardware - locally and in the cloud.
 - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
 - AVX, AVX2, AVX512 and AMX support for x86 architectures
 - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
-- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA)
+- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
 - Vulkan and SYCL backend support
 - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity
 
@@ -237,7 +237,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 | [BLAS](docs/build.md#blas-build) | All |
 | [BLIS](docs/backend/BLIS.md) | All |
 | [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU |
-| [MUSA](docs/build.md#musa) | Moore Threads MTT GPU |
+| [MUSA](docs/build.md#musa) | Moore Threads GPU |
 | [CUDA](docs/build.md#cuda) | Nvidia GPU |
 | [HIP](docs/build.md#hip) | AMD GPU |
 | [Vulkan](docs/build.md#vulkan) | GPU |
index ec3f44350394a56859881a621214cf16d2fe5fc2..6e297f1a82788096aa803848766a350c7fa81af8 100644 (file)
@@ -54,7 +54,7 @@ docker run --privileged -it \
     -v $HOME/llama.cpp/ci-cache:/ci-cache \
     -v $HOME/llama.cpp/ci-results:/ci-results \
     -v $PWD:/ws -w /ws \
-    mthreads/musa:rc3.1.1-devel-ubuntu22.04
+    mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
 ```
 
 Inside the container, execute the following commands:
index 3f4d0cc4fafa7e4e93506254a46bcc05e2590bd0..f8f0573c17239cd69a88bbcb3452bf3eaaa48a7a 100644 (file)
@@ -107,7 +107,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
 
 The defaults are:
 
-- `MUSA_VERSION` set to `rc3.1.1`
+- `MUSA_VERSION` set to `rc4.0.1`
 
 The resulting images, are essentially the same as the non-MUSA images:
 
index d027271fcd932d895bdda98cd52a7293ae262ca3..2c55d2149b2d32bbbebbaeb063c2f09d8bddd6c6 100644 (file)
@@ -1,5 +1,8 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
+#ifdef GGML_USE_MUSA
+#include "ggml-musa/mudnn.cuh"
+#endif // GGML_USE_MUSA
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -597,7 +600,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 #endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+#ifdef GGML_USE_MUSA
+        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
+            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
+        } else
+#endif // GGML_USE_MUSA
+        {
+            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
index be0329d0e0c0953fb76077aafc3d8c7e753a5235..7120053b6ee01efe73b729ee83f1e352c0ce12d8 100644 (file)
@@ -772,7 +772,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
-    GGML_UNUSED(kb0);
+    GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
index 92f05d5558c80bff0c1cf9f5bf272fee0e2ef80e..971314debc714ffb805ded0d9d4ac1e0101f0024 100644 (file)
@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
 
     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
+    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
+    file(GLOB   SRCS "../ggml-musa/*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 
     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
                             )
 
     # TODO: do not use CUDA definitions for MUSA
-    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+    if (NOT GGML_BACKEND_DL)
+        target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+    endif()
 
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
     endif()
 
     if (GGML_STATIC)
+        # TODO: mudnn has not provided static libraries yet
         target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
     else()
-        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
     endif()
 
     if (GGML_CUDA_NO_VMM)
diff --git a/ggml/src/ggml-musa/mudnn.cu b/ggml/src/ggml-musa/mudnn.cu
new file mode 100644 (file)
index 0000000..020c170
--- /dev/null
@@ -0,0 +1,112 @@
+#include <mutex>
+#include <mudnn.h>
+
+#include "mudnn.cuh"
+
+namespace mudnn = musa::dnn;
+
+// Returns a human-readable error string for mudnn::Status
+const char* mudnnGetErrorString(mudnn::Status err) {
+    switch (err) {
+        case mudnn::Status::SUCCESS:
+            return "Success";
+        case mudnn::Status::INVALID_PARAMETER:
+            return "Invalid parameter";
+        case mudnn::Status::NOT_INITIALIZED:
+            return "Not initialized";
+        case mudnn::Status::ALLOC_FAILED:
+            return "Allocation failed";
+        case mudnn::Status::NOT_SUPPORTED:
+            return "Not supported";
+        case mudnn::Status::INTERNAL_ERROR:
+            return "Internal error";
+        case mudnn::Status::ARCH_MISMATCH:
+            return "Architecture mismatch";
+        case mudnn::Status::EXECUTION_FAILED:
+            return "Execution failed";
+        default:
+            return "Unknown mudnn status";
+    }
+}
+
+// Error checking macro for MUDNN calls
+#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
+
+namespace {
+    // Thread-safe cache for mudnn::Handle objects per device
+    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
+    std::mutex handle_cache_mutex;
+
+    mudnn::Handle* get_cached_handle(int device_id) {
+        std::lock_guard<std::mutex> lock(handle_cache_mutex);
+        auto it = handle_cache.find(device_id);
+        if (it != handle_cache.end()) {
+            return it->second.get();
+        }
+        auto handle = std::make_unique<mudnn::Handle>(device_id);
+        mudnn::Handle* handle_ptr = handle.get();
+        handle_cache[device_id] = std::move(handle);
+        return handle_ptr;
+    }
+}
+
+// Extracts dimensions and strides from a ggml_tensor
+int get_ggml_dims_and_strides(const ggml_tensor* tensor,
+                              std::vector<int64_t>& dims,
+                              std::vector<int64_t>& strides) {
+    const int ndims = ggml_n_dims(tensor);
+    const size_t element_size = ggml_element_size(tensor);
+
+    dims.resize(ndims);
+    strides.resize(ndims);
+
+    for (int i = 0; i < ndims; ++i) {
+        dims[i] = tensor->ne[i];
+        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
+    }
+    return ndims;
+}
+
+// Converts ggml_type to mudnn::Tensor::Type
+mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return mudnn::Tensor::Type::FLOAT;
+        case GGML_TYPE_F16:
+            return mudnn::Tensor::Type::HALF;
+
+        // TODO: Add support for other types
+
+        default:
+            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
+    }
+
+    return mudnn::Tensor::Type::FLOAT; // Default fallback
+}
+
+// Asynchronous memory copy using mudnn::Unary::IDENTITY
+musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
+    mudnn::Tensor tensor_dst, tensor_src;
+
+    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
+    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
+
+    std::vector<int64_t> dims, strides;
+    const int ndims = get_ggml_dims_and_strides(src, dims, strides);
+
+    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
+    MUDNN_CHECK(tensor_src.SetAddr(src->data));
+
+    mudnn::Unary op;
+    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
+    MUDNN_CHECK(op.SetAlpha(0.0f));
+    MUDNN_CHECK(op.SetBeta(0.0f));
+
+    mudnn::Handle* handle = get_cached_handle(ctx.device);
+    MUDNN_CHECK(handle->SetStream(ctx.stream()));
+    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
+
+    return musaSuccess;
+}
diff --git a/ggml/src/ggml-musa/mudnn.cuh b/ggml/src/ggml-musa/mudnn.cuh
new file mode 100644 (file)
index 0000000..a63be57
--- /dev/null
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "../include/ggml.h"
+#include "../ggml-cuda/common.cuh"
+
+// Asynchronously copies data from src tensor to dst tensor using the provided context.
+// Returns a musaError_t indicating success or failure.
+musaError_t mudnnMemcpyAsync(
+    ggml_backend_cuda_context &ctx,
+    const ggml_tensor *dst,
+    const ggml_tensor *src
+);