]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Update CUDA graph on scale change plus clear nodes/params (llama/9550)
authoragray3 <redacted>
Sat, 21 Sep 2024 00:41:07 +0000 (01:41 +0100)
committerGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 10:04:37 +0000 (13:04 +0300)
* Avoid using saved CUDA graph if scale changes and reset nodes/params on update

Fixes https://github.com/ggerganov/llama.cpp/issues/9451

* clear before resize

src/ggml-cuda.cu
src/ggml-cuda/common.cuh

index b0843dc621cf54989ab7d01821ecd05683e376b4..895ba479483e71baf6f806270a7068ae1fc99982 100644 (file)
@@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
     }
+    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
 static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
             return false;
         }
     }
+
+    if (node->op == GGML_OP_SCALE &&
+        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }
+
     return true;
 }
 
@@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             // First call with null argument gets number of nodes in graph
             CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
             // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.clear();
             cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.clear();
             cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
             if (cuda_ctx->cuda_graph->num_nodes > 0) {
                 CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
index eb39b6d23a6b3f21ef6b0fa91482c0fdf9b2e1c3..85eb200f03b06a3f0de01d17411d83faaf94846c 100644 (file)
@@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
     void * src_address[GGML_MAX_SRC];
+    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
 struct ggml_cuda_graph {