]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: check for memory overlap before doing fusion (llama/19768)
authorJeff Bolz <redacted>
Wed, 25 Feb 2026 17:25:38 +0000 (11:25 -0600)
committerGeorgi Gerganov <redacted>
Fri, 27 Feb 2026 10:04:54 +0000 (12:04 +0200)
* vulkan: check for memory overlap before doing fusion

* Update ggml/src/ggml-vulkan/ggml-vulkan.cpp

* address feedback

src/ggml-vulkan/ggml-vulkan.cpp

index 8a9cfaf16547afe90d01bdf44f6a40ec2fbbdebe..a1149e606e405dc523c747586b3ff73f7e3801f7 100644 (file)
@@ -13820,12 +13820,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
     return true;
 }
 
-// Check whether the tensors overlap in memory but are not equal.
-// Fusions can potenitally overwrite src tensors in ways that are not prevented
-// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
-// to overlap if they are exactly equal.
-// XXX TODO this check is probably missing from several fusion optimizations.
-static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+// Check whether the tensors overlap in memory.
+// Fusions can potentially overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
+// with the destination, then it's OK for them to overlap if they are exactly equal.
+static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
     ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
     vk_buffer a_buf = a_buf_ctx->dev_buffer;
     ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
@@ -13836,7 +13835,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
         auto b_base = vk_tensor_offset(b) + b->view_offs;
         auto b_size = ggml_nbytes(b);
 
-        if (a_base == b_base && a_size == b_size) {
+        if (elementwise && a_base == b_base && a_size == b_size) {
             return false;
         }
 
@@ -13874,13 +13873,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
         return false;
     }
 
-    // must not overwrite srcs in a way that's not elementwise
-    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
-    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
-        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
-        return false;
-    }
-
     // conditions for pipeline creation
     if (!(ctx->device->float_controls_rte_fp16 &&
         sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
@@ -13942,6 +13934,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
     return num_adds;
 }
 
+static int32_t find_first_set(uint32_t x) {
+    int32_t ret = 0;
+    if (!x) {
+        return -1;
+    }
+    while (!(x & 1)) {
+        x >>= 1;
+        ret++;
+    }
+    return ret;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -14040,6 +14044,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
+        // the fused result in an elementwise-way. This affects whether the memory for
+        // the src is allowed to overlap the memory for the destination.
+        // The array is sized to handle the largest fusion (asserted later).
+        bool op_srcs_fused_elementwise[12];
+
         ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
         ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
@@ -14048,39 +14058,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
                 fusion_string = "MULTI_ADD";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ADD_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_ADD_ID";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
                        ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
                 ctx->num_additional_fused_ops = 4;
                 fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
+                op_srcs_fused_elementwise[3] = false;
+                op_srcs_fused_elementwise[4] = false;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "RMS_NORM_MUL_ROPE";
+                // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "RMS_NORM_MUL";
+                // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
+                // they are overwritten, and one workgroup per row. So close enough.
+                op_srcs_fused_elementwise[0] = true;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                        ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -14089,6 +14128,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 3;
                 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
@@ -14097,6 +14137,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 4;
                 ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
                 fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
@@ -14105,6 +14146,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 3;
                 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
@@ -14113,6 +14155,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 1;
                 ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             }
             if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
                 // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
@@ -14120,11 +14163,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                     ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
                     ctx->fused_topk_moe_scale = true;
                     ctx->num_additional_fused_ops++;
+                    op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
                 }
             }
         }
+        GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
+        // Check whether fusion would overwrite src operands while they're still in use.
+        // If so, disable fusion.
+        if (ctx->num_additional_fused_ops) {
+            // There are up to two output nodes - topk_moe has two.
+            uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
+            ggml_tensor *output_nodes[2] {};
+            output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
+            if (bits) {
+                int output_idx = find_first_set(bits);
+                GGML_ASSERT(bits == (1u << output_idx));
+                output_nodes[1] = cgraph->nodes[i + output_idx];
+            }
+
+            bool need_disable = false;
+
+            // topk_moe often overwrites the source, but for a given row all the src values are
+            // loaded before anything is stored. If there's only one row, this is safe, so treat
+            // this as a special case.
+            bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
+                                          ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
+
+            if (!is_topk_moe_single_row) {
+                for (int j = 0; j < 2; ++j) {
+                    ggml_tensor *dst = output_nodes[j];
+                    if (!dst) {
+                        continue;
+                    }
+                    // Loop over all srcs of all nodes in the fusion. If the src overlaps
+                    // the destination and the src is not an intermediate node that's being
+                    // elided, then disable fusion.
+                    for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
+                        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+                            ggml_tensor *src = cgraph->nodes[i + k]->src[s];
+                            if (!src || src->op == GGML_OP_NONE) {
+                                continue;
+                            }
+                            if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
+                                bool found = false;
+                                for (int n = 0; n < k; ++n) {
+                                    if (cgraph->nodes[i + n] == src) {
+                                        found = true;
+                                        break;
+                                    }
+                                }
+                                if (!found) {
+                                    need_disable = true;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            if (need_disable) {
+                ctx->num_additional_fused_ops = 0;
+                ctx->fused_ops_write_mask = 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+                ctx->fused_topk_moe_scale = false;
+            }
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||