]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : fix nkvo, offload and cuda graph node properties matching (#19165)
authorGeorgi Gerganov <redacted>
Thu, 29 Jan 2026 16:45:30 +0000 (18:45 +0200)
committerGitHub <redacted>
Thu, 29 Jan 2026 16:45:30 +0000 (18:45 +0200)
* cuda : fix nkvo

* cont : more robust cuda graph node property matching

* cont : restore pre-leafs implementation

* cont : comments + static_assert

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu
src/llama-graph.cpp

index 3335f443aeb4106b940cdb720dabebf339ce124a..43280644e4865b36d5cafb4812e182cd866270dd 100644 (file)
@@ -1122,15 +1122,17 @@ struct ggml_tensor_extra_gpu {
 #endif
 
 struct ggml_cuda_graph_node_properties {
-    void * node_address;
+    void * node_data;
     ggml_op node_op;
     int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
-    void * src_address[GGML_MAX_SRC];
+    void * src_data[GGML_MAX_SRC];
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
+static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
+
 struct ggml_cuda_graph {
 #ifdef USE_CUDA_GRAPH
     ~ggml_cuda_graph() {
@@ -1150,6 +1152,12 @@ struct ggml_cuda_graph {
     int number_consecutive_updates = 0;
     std::vector<ggml_cuda_graph_node_properties> props;
 
+    // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
+    // they properties also have to match in order to be able to safely reuse a CUDA graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18583
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19165
+    std::vector<ggml_cuda_graph_node_properties> extra;
+
     void record_update(bool use_graph, bool update_required) {
         if (use_graph && update_required) {
             number_consecutive_updates++;
index 195904ee2061eb12566f7f2aa64ea7c5b03fc18d..721edd9994407d62b22e4cba939fbb3b6a1916ab 100644 (file)
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         }
     }
 
-    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
-
     const int cc = ggml_cuda_info().devices[device].cc;
 
     switch (K->ne[0]) {
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!V_is_K_view) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;
index 76d0f12550e1e306f6a790c32ab800b585efec4d..cfcffde8a21fcf79646fd7d63a7fa5730c07c897 100644 (file)
 #include <condition_variable>
 #include <cstddef>
 #include <cstdint>
-#include <float.h>
+#include <cfloat>
 #include <initializer_list>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdarg.h>
-#include <stdio.h>
-#include <stdlib.h>
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
 #include <string>
 #include <vector>
+#include <unordered_set>
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
@@ -2916,7 +2917,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 }
 
 static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
-    props->node_address = node->data;
+    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
+    props->node_data = node->data;
     props->node_op = node->op;
     props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
@@ -2924,14 +2926,17 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties
         props->nb[i] = node->nb[i];
     }
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+        if (!node->src[i]) {
+            continue;
+        }
+
+        props->src_data[i] = node->src[i]->data;
     }
     memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
 static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
-    if (node->data != props->node_address &&
-          node->op != GGML_OP_VIEW) {
+    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
         return false;
     }
 
@@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         }
     }
 
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != props->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+    if (node->op != GGML_OP_VIEW) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (!node->src[i]) {
+                if (props->src_data[i] != nullptr) {
+                    return false;
+                }
+                continue;
+            }
+
+            if (node->src[i]->data != props->src_data[i]) {
+                return false;
+            }
         }
     }
 
@@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
 }
 
 static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
-
     bool res = false;
 
     const void * graph_key = ggml_cuda_graph_get_key(cgraph);
@@ -2985,15 +2995,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
     }
 
     // Check if the graph size has changed
-    if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
+    if (graph->props.size() != (size_t)cgraph->n_nodes) {
         res = true;
-        graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
+        graph->props.resize(cgraph->n_nodes);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
+    std::unordered_set<ggml_tensor *> seen_node;
+    std::vector<ggml_tensor *> srcs_extra;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool props_match = true;
+
+        seen_node.insert(cgraph->nodes[i]);
+
         if (!res) {
             props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
         }
@@ -3001,17 +3016,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
             res = true;
         }
         ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
+
+        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
+            if (src && seen_node.find(src) == seen_node.end()) {
+                srcs_extra.push_back(src);
+            }
+        }
+    }
+
+    if (graph->extra.size() != (size_t) srcs_extra.size()) {
+        res = true;
+        graph->extra.resize(srcs_extra.size());
     }
 
-    for (int i = 0; i < cgraph->n_leafs; i++) {
+    for (size_t i = 0; i < srcs_extra.size(); ++i) {
         bool props_match = true;
+
         if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
+            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
         }
+
         if (!props_match) {
             res = true;
         }
-        ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
+        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
     }
 
     return res;
index b3198b7e3a22b6a0fbb2541b684877dc9de0be05..16d42c4ae3de244befd15f72fa798dea6983f0a8 100644 (file)
@@ -1630,11 +1630,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
         cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
 
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
-        }
-
         ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);