]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : ROCm AMD Unified Memory Architecture (UMA) handling (#4449)
authorErik Garrison <redacted>
Thu, 21 Dec 2023 19:45:32 +0000 (13:45 -0600)
committerGitHub <redacted>
Thu, 21 Dec 2023 19:45:32 +0000 (21:45 +0200)
* AMD ROCm: handle UMA memory VRAM expansions

This resolves #2797 by allowing ROCm AMD GPU users with a UMA to
dynamically expand the VRAM allocated to the GPU.

Without this, AMD ROCm users with shared CPU/GPU memory usually are
stuck with the BIOS-set (or fixed) framebuffer VRAM, making it
impossible to load more than 1-2 layers.

Note that the model is duplicated in RAM because it's loaded once for
the CPU and then copied into a second set of allocations that are
managed by the HIP UMA system. We can fix this later.

* clarify build process for ROCm on linux with cmake

* avoid using deprecated ROCm hipMallocHost

* keep simplifying the change required for UMA

* cmake: enable UMA-compatible allocation when LLAMA_HIP_UMA=ON

CMakeLists.txt
README.md
ggml-cuda.cu

index e3cd43ab36f0626efbfd74335d5e65a8f81dcaf8..6fc6508c598ff1bf64885c63c4af105ce85a1396 100644 (file)
@@ -91,6 +91,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                              "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
+option(LLAMA_HIP_UMA                         "llama: use HIP unified memory architecture"       OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                 ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"                   OFF)
@@ -377,6 +378,9 @@ if (LLAMA_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        if (LLAMA_HIP_UMA)
+            add_compile_definitions(GGML_HIP_UMA)
+        endif()
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
index 80ce194ca91de294dc5381fa952dc1668e7d1291..73fe59bb40fd32035c8bb070c540d92cba28c923 100644 (file)
--- a/README.md
+++ b/README.md
@@ -432,14 +432,15 @@ Building the program with BLAS support may lead to some performance improvements
     ```bash
     make LLAMA_HIPBLAS=1
     ```
-  - Using `CMake` for Linux:
+  - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
     ```bash
-    mkdir build
-    cd build
-    CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
-    cmake --build .
+    CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
+        cmake -H. -Bbuild -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
+        && cmake --build build -- -j 16
     ```
-  - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS):
+    On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON"`.
+    However, this hurts performance for non-integrated GPUs.
+  - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
     ```bash
     set PATH=%HIP_PATH%\bin;%PATH%
     mkdir build
@@ -448,10 +449,11 @@ Building the program with BLAS support may lead to some performance improvements
     cmake --build .
     ```
     Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
+    Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
 
 
   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
-  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
   The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
 
   | Option                  | Legal values           | Default | Description |
index 61d92d7ef61edc01b6577e82f4b6c7367f508b8c..32603a8d16a78d81e1492215ac8f35a46bf8edba 100644 (file)
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
+#ifdef GGML_HIP_UMA
+#define cudaMalloc hipMallocManaged
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
+#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
 #define cudaMemcpyAsync hipMemcpyAsync