]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add LLAMA_CUDA_NO_PEER_COPY to workaround broken ROCm p2p copy (#6208)
authorslaren <redacted>
Fri, 22 Mar 2024 13:05:31 +0000 (14:05 +0100)
committerGitHub <redacted>
Fri, 22 Mar 2024 13:05:31 +0000 (14:05 +0100)
* cuda : add LLAMA_CUDA_NO_PEER_COPY to workaround broken ROCm p2p copy

* add LLAMA_CUDA_NO_PEER_COPY to HIP build

CMakeLists.txt
Makefile
ggml-cuda.cu

index fc4cff28f44ac047b29a9b088ee741e6d75875c9..3333ee1c9d501a8a139d8c765cde8c339c05ed58 100644 (file)
@@ -99,6 +99,7 @@ option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                              "llama: max. batch size for using peer access")
+option(LLAMA_CUDA_NO_PEER_COPY               "llama: do not use peer to peer copies"            OFF)
 option(LLAMA_CURL                            "llama: use libcurl to download model from an URL" OFF)
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
 option(LLAMA_HIP_UMA                         "llama: use HIP unified memory architecture"       OFF)
@@ -387,6 +388,9 @@ if (LLAMA_CUBLAS)
         endif()
         add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
         add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
+        if (LLAMA_CUDA_NO_PEER_COPY)
+            add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+        endif()
 
         if (LLAMA_STATIC)
             if (WIN32)
@@ -531,6 +535,10 @@ if (LLAMA_HIPBLAS)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
 
+    if (LLAMA_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
+
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
index 9b72e1dbd54939253a01e6b01530fe711d88791a..fa112e7081fa7fb5aba93844b630c4a2b4ca1112 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -452,9 +452,9 @@ ifdef LLAMA_CUDA_PEER_MAX_BATCH_SIZE
 else
        MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128
 endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
-#ifdef LLAMA_CUDA_CUBLAS
-#      MK_NVCCFLAGS += -DGGML_CUDA_CUBLAS
-#endif # LLAMA_CUDA_CUBLAS
+ifdef LLAMA_CUDA_NO_PEER_COPY
+       MK_NVCCFLAGS += -DGGML_CUDA_NO_PEER_COPY
+endif # LLAMA_CUDA_NO_PEER_COPY
 ifdef LLAMA_CUDA_CCBIN
        MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
@@ -535,6 +535,9 @@ endif # LLAMA_HIP_UMA
 ifdef LLAMA_CUDA_FORCE_DMMV
        HIPFLAGS        += -DGGML_CUDA_FORCE_DMMV
 endif # LLAMA_CUDA_FORCE_DMMV
+ifdef LLAMA_CUDA_NO_PEER_COPY
+       HIPFLAGS        += -DGGML_CUDA_NO_PEER_COPY
+endif # LLAMA_CUDA_NO_PEER_COPY
        OBJS        += ggml-cuda.o
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
        $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
index 14f409eb1dd6aae6f5581101c59b081ecf8dffb9..adf930478f5ff3286b1349fe206c62610f72d88a 100644 (file)
@@ -771,7 +771,11 @@ GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t
         if (src_ctx->device == dst_ctx->device) {
             CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
         } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
             CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+#endif
         }
         CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
         return true;
@@ -11322,19 +11326,23 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
         GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
         GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
 
-        if (!cuda_ctx_src->copy_event) {
-            ggml_cuda_set_device(cuda_ctx_src->device);
-            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
-        }
-
         // copy on src stream
         if (cuda_ctx_src->device == cuda_ctx_dst->device) {
             CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
         } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
             CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
+#endif
         }
 
         // record event on src stream
+        if (!cuda_ctx_src->copy_event) {
+            ggml_cuda_set_device(cuda_ctx_src->device);
+            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+        }
+
         CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
 
         // wait on dst stream for the copy to complete
@@ -11530,6 +11538,9 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const
 }
 
 static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
+#ifdef GGML_CUDA_NO_PEER_COPY
+    return nullptr;
+#else
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
@@ -11541,6 +11552,7 @@ static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend)
         /* .backend = */ backend,
         /* .context = */ event,
     };
+#endif
 }
 
 static void ggml_backend_cuda_event_free(ggml_backend_event_t event) {