]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: disable cuda graph when using n-cpu-moe (#18593)
authorAman Gupta <redacted>
Sun, 4 Jan 2026 17:37:48 +0000 (01:37 +0800)
committerGitHub <redacted>
Sun, 4 Jan 2026 17:37:48 +0000 (01:37 +0800)
* CUDA: disable cuda graph when using n-cpu-moe

* call ggml_cuda_set_device

ggml/src/ggml-cuda/ggml-cuda.cu

index f05d5562ba0483b56375bdc748b96a9dec77ba8f..80d983f9eef078374c4aa40bcde8bf6419f08653 100644 (file)
@@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 }
 
 static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
+
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
 static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
+    ggml_cuda_set_device(cuda_ctx->device);
+
     bool use_cuda_graph             = false;
     bool cuda_graph_update_required = false;
 
     // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
     // we call it here instead.
 #ifdef USE_CUDA_GRAPH
-    if (!cuda_ctx->cuda_graph) {
-        use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
-    } else {
-        use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
-    }
+    use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
 
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
@@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+            cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif