]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
add tensor type checking as part of cuda graph properties (llama/19186)
authorbssrdf <redacted>
Fri, 30 Jan 2026 04:57:52 +0000 (23:57 -0500)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
src/ggml-cuda/common.cuh
src/ggml-cuda/ggml-cuda.cu

index 43280644e4865b36d5cafb4812e182cd866270dd..a3256d59dd06319a6397ce6562293b38e31e4f34 100644 (file)
@@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu {
 struct ggml_cuda_graph_node_properties {
     void * node_data;
     ggml_op node_op;
+    enum ggml_type node_type;
     int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
index cfcffde8a21fcf79646fd7d63a7fa5730c07c897..e9e9592ebad1eb4450b80ebf5d13f52ceb280a23 100644 (file)
@@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties
     memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
     props->node_data = node->data;
     props->node_op = node->op;
+    props->node_type = node->type;
     props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         props->ne[i] = node->ne[i];
@@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         return false;
     }
 
+    if (node->type != props->node_type) {
+        return false;
+    }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != props->ne[i]) {
             return false;