]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: enable peer access between devices (#2470)
authorJohannes Gäßler <redacted>
Sun, 17 Sep 2023 14:37:53 +0000 (16:37 +0200)
committerGitHub <redacted>
Sun, 17 Sep 2023 14:37:53 +0000 (16:37 +0200)
CMakeLists.txt
Makefile
README.md
ggml-cuda.cu

index abecd684ba2fc0b9d6ea7fa49691c1442873ff88..c0b93564a53dd3ed3861b2368c2b64bf6eba52a0 100644 (file)
@@ -80,6 +80,8 @@ set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kern
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some calculations"   OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+                                             "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                               OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                 ${LLAMA_METAL_DEFAULT})
@@ -304,6 +306,7 @@ if (LLAMA_CUBLAS)
             add_compile_definitions(GGML_CUDA_F16)
         endif()
         add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
 
         if (LLAMA_STATIC)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
index 73a9fb17a3ebedf3c7d759b112870cdccef67a47..dc8ae380756530ac88e4d0f46212605c5b3f5e6e 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -368,6 +368,11 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
 else
        NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
 endif
+ifdef LLAMA_CUDA_PEER_MAX_BATCH_SIZE
+       NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(LLAMA_CUDA_PEER_MAX_BATCH_SIZE)
+else
+       NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128
+endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
 #ifdef LLAMA_CUDA_CUBLAS
 #      NVCCFLAGS += -DGGML_CUDA_CUBLAS
 #endif # LLAMA_CUDA_CUBLAS
index b3845afd70a462483a46bbaf4085d0c1584dea73..d8fd8bc4478e3bd29cb3bd5219e2499b17d52595 100644 (file)
--- a/README.md
+++ b/README.md
@@ -391,13 +391,14 @@ Building the program with BLAS support may lead to some performance improvements
 <!---
   | LLAMA_CUDA_CUBLAS       | Boolean                |   false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
 --->
-  | Option                  | Legal values           | Default | Description |
-  |-------------------------|------------------------|---------|-------------|
-  | LLAMA_CUDA_FORCE_DMMV   | Boolean                |   false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
-  | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 |      32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
-  | LLAMA_CUDA_MMV_Y        | Positive integer       |       1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
-  | LLAMA_CUDA_F16          | Boolean                |   false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
-  | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 |       2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
+  | Option                         | Legal values           | Default | Description |
+  |--------------------------------|------------------------|---------|-------------|
+  | LLAMA_CUDA_FORCE_DMMV          | Boolean                |   false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
+  | LLAMA_CUDA_DMMV_X              | Positive integer >= 32 |      32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
+  | LLAMA_CUDA_MMV_Y               | Positive integer       |       1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
+  | LLAMA_CUDA_F16                 | Boolean                |   false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
+  | LLAMA_CUDA_KQUANTS_ITER        | 1 or 2                 |       2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
+  | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer       |     128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
 
 - #### hipBLAS
 
index 248cb2c426f0663f26a41135145367d3db43d647..5346b9e09519ac98fb6709bd2643a9a3e9d4699c 100644 (file)
@@ -31,6 +31,9 @@
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -424,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
 #define MAX_STREAMS 8
@@ -6258,6 +6265,41 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }
 
+void ggml_cuda_set_peer_access(const int n_tokens) {
+    static bool peer_access_enabled = false;
+
+    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        for (int id_other = 0; id_other < g_device_count; ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != g_main_device && id_other != g_main_device) {
+                continue;
+            }
+
+            int canAccessPeer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, id, id_other));
+            if (enable_peer_access) {
+                CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
+            } else {
+                CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
+            }
+        }
+    }
+#endif // NDEBUG
+
+    peer_access_enabled = enable_peer_access;
+}
+
 static void ggml_cuda_op_mul_mat(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
     const bool convert_src1_to_q8_1) {
@@ -6282,6 +6324,8 @@ static void ggml_cuda_op_mul_mat(
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
+    ggml_cuda_set_peer_access(ne11);
+
     GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
 
@@ -7010,7 +7054,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
-void ggml_cuda_set_main_device(int main_device) {
+void ggml_cuda_set_main_device(const int main_device) {
     if (main_device >= g_device_count) {
         fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                 main_device, g_device_count, g_main_device);
@@ -7024,11 +7068,11 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }
 
-void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
     g_mul_mat_q = mul_mat_q;
 }
 
-void ggml_cuda_set_scratch_size(size_t scratch_size) {
+void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     g_scratch_size = scratch_size;
 }