int n_streams = 0;
std::unordered_map<const ggml_tensor *, int> stream_mapping;
+ // Original order of nodes in this concurrent region (before interleaving)
+ // Used to restore grouping for fusion within streams
+ std::vector<const ggml_tensor *> original_order;
+
const ggml_tensor * join_node;
ggml_cuda_concurrent_event() = default;
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
+ , original_order(std::move(other.original_order))
, join_node(other.join_node) {
other.fork_event = nullptr;
}
};
struct ggml_cuda_stream_context {
- std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
void reset() {
- original_nodes.clear();
concurrent_events.clear();
}
};
}
}
if (should_launch_concurrent_events) {
- //Restore the original graph to enable fusion within the streams
- cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
- cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+ // Restore original node order within each concurrent region to enable fusion within streams
+
+ std::unordered_map<const ggml_tensor *, int> node_to_idx;
+ node_to_idx.reserve(cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ node_to_idx[cgraph->nodes[i]] = i;
+ }
+
+ for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
+ // Find positions of all nodes from this event in the current graph
+ std::vector<int> positions;
+ positions.reserve(event.original_order.size());
+
+ bool all_found = true;
+ for (const ggml_tensor * orig_node : event.original_order) {
+ auto it = node_to_idx.find(orig_node);
+ if (it != node_to_idx.end()) {
+ positions.push_back(it->second);
+ } else {
+ all_found = false;
+ break;
+ }
+ }
+
+ if (!all_found || positions.size() != event.original_order.size()) {
+ continue;
+ }
+
+ // Sort positions to get contiguous range
+ std::vector<int> sorted_positions = positions;
+ std::sort(sorted_positions.begin(), sorted_positions.end());
+
+ bool is_contiguous = true;
+ for (size_t i = 1; i < sorted_positions.size(); ++i) {
+ if (sorted_positions[i] != sorted_positions[i-1] + 1) {
+ is_contiguous = false;
+ break;
+ }
+ }
+
+ if (!is_contiguous) {
+ continue;
+ }
+
+ // Restore original order at the sorted positions
+ int start_pos = sorted_positions[0];
+ for (size_t i = 0; i < event.original_order.size(); ++i) {
+ cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
+ }
+ }
}
for (int i = 0; i < cgraph->n_nodes; i++) {
// store {fork_idx, join_idx}
std::vector<std::pair<int, int>> concurrent_node_ranges;
- // save the original nodes
- std::vector<const ggml_tensor *> original_nodes;
- original_nodes.reserve(cgraph->n_nodes);
- for (int i = 0; i < cgraph->n_nodes; ++i) {
- original_nodes.push_back(cgraph->nodes[i]);
- }
- cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
-
for (const auto & [root_node, count] : fan_out) {
if (count >= min_fan_out && count <= max_fan_out) {
const int root_node_idx = node_indices[root_node];
continue;
}
+ // Save the original order of nodes in this region before interleaving
+ // This is used later to restore grouping for fusion within streams
+ concurrent_event.original_order.reserve(total_branch_nodes);
+ for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
+ concurrent_event.original_order.push_back(cgraph->nodes[i]);
+ }
+
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
concurrent_events.emplace(root_node, std::move(concurrent_event));