From: bssrdf Date: Fri, 30 Jan 2026 04:57:52 +0000 (-0500) Subject: add tensor type checking as part of cuda graph properties (llama/19186) X-Git-Tag: v0.9.6~5 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=00d7d43e4c54e6cbf63940a8ca668308d0eab63e;p=pkg%2Fggml%2Fsources%2Fggml add tensor type checking as part of cuda graph properties (llama/19186) --- diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh index 43280644..a3256d59 100644 --- a/src/ggml-cuda/common.cuh +++ b/src/ggml-cuda/common.cuh @@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu { struct ggml_cuda_graph_node_properties { void * node_data; ggml_op node_op; + enum ggml_type node_type; int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index cfcffde8..e9e9592e 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); props->node_data = node->data; props->node_op = node->op; + props->node_type = node->type; props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; @@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if (node->type != props->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != props->ne[i]) { return false;