]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : synchronize graph capture and cublas handle destruction (llama/14288)
authorDiego Devesa <redacted>
Fri, 20 Jun 2025 11:57:36 +0000 (04:57 -0700)
committerGeorgi Gerganov <redacted>
Sat, 21 Jun 2025 04:34:17 +0000 (07:34 +0300)
Workarounds an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

index c14a12f54a8d6cb73b0fad9ecc0c6a1c05b60ac9..364efcaeccc0796980417beae7b56b348b7f222e 100644 (file)
 #endif
 #include "ggml-common.h"
 
-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
index 80fe050734dfa72adecfff8671f4a24b15d7ece5..530f541f97d6297d4a677b7192ea7abfb835b25b 100644 (file)
@@ -48,6 +48,7 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
@@ -55,9 +56,8 @@
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -515,6 +515,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
@@ -2689,6 +2716,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2764,7 +2796,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }