return false;
}
+// Returns true when the fusion candidate is memory-safe: none of the fused
+// nodes reads a source tensor whose device memory range overlaps one of the
+// written (out) tensors, unless that source is an intermediate produced by an
+// earlier node of the same fusion (and therefore elided by it).
+//
+//   cgraph     - graph being evaluated
+//   node_idx   - index of the first node of the fusion candidate
+//   node_count - number of consecutive nodes in the fusion
+//   out_nodes  - indices of the nodes whose outputs are actually written
+//   out_count  - number of entries in out_nodes
+static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
+                                                 int           node_idx,
+                                                 int           node_count,
+                                                 int *         out_nodes,
+                                                 int           out_count) {
+    // Half-open interval overlap test on the tensors' data address ranges.
+    // uintptr_t is used instead of int64_t: pointer -> signed integer
+    // conversion is implementation-defined (CERT INT36-C).
+    auto nodes_overlap = [](const ggml_tensor * a, const ggml_tensor * b) {
+        const uintptr_t a_start = (uintptr_t) a->data;
+        const uintptr_t a_end   = a_start + ggml_nbytes(a);
+
+        const uintptr_t b_start = (uintptr_t) b->data;
+        const uintptr_t b_end   = b_start + ggml_nbytes(b);
+
+        return (b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end);
+    };
+
+    // for nrows=1, all fusion operations correctly read the src before writing dst
+    // or do it elementwise, so we should be ok
+    if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
+        return true;
+    }
+
+    for (int i = 0; i < out_count; ++i) {
+        const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
+
+        // Loop over all srcs of all nodes in the fusion. If a src overlaps the
+        // destination and is not an intermediate node that's being elided by
+        // the fusion, disable fusion.
+        for (int j = node_idx; j < node_idx + node_count; ++j) {
+            for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+                const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
+
+                // NOTE(review): leaf tensors (op == GGML_OP_NONE) are skipped,
+                // i.e. assumed not to alias fusion outputs — confirm this holds
+                // for graphs with in-place/viewed inputs.
+                if (!src || src->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+                if (!nodes_overlap(dst, src)) {
+                    continue;
+                }
+
+                // An overlapping src is acceptable only when it is produced by
+                // an earlier node of this same fusion (an elided intermediate).
+                bool is_intermediate = false;
+                for (int k = node_idx; k < j; ++k) {
+                    if (cgraph->nodes[k] == src) {
+                        is_intermediate = true;
+                        break;
+                    }
+                }
+
+                if (!is_intermediate) {
+                    // unsafe: the fused write would clobber a live input, so
+                    // bail out immediately (no need to scan remaining nodes)
+                    return false;
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false;
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
- ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
- ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;