#include "ggml-common.h"
#include <array>
+#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
+#include <unordered_map>
#include <vector>
#if defined(GGML_USE_HIP)
#endif
};
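+// Describes a fork/join region of the compute graph whose branches run on separate CUDA streams:
+// fork_event is recorded on the main stream and waited on by every branch stream, and the
+// per-branch join_events are waited on by the main stream again when join_node is reached.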
+struct ggml_cuda_concurrent_event {
+ std::vector<cudaEvent_t> join_events;
+ cudaEvent_t fork_event = nullptr;
+
+ int n_streams = 0;
+ std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+ const ggml_tensor * join_node = nullptr;
+
+ ggml_cuda_concurrent_event() = default;
+
+ ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
+ ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
+
+ explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+ join_events.resize(n_streams);
+
+ for (size_t i = 0; i < join_events.size(); ++i) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
+ }
+
+ CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
+ }
+
+ ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
+ : join_events(std::move(other.join_events))
+ , fork_event(other.fork_event)
+ , n_streams(other.n_streams)
+ , stream_mapping(std::move(other.stream_mapping))
+ , join_node(other.join_node) {
+ other.fork_event = nullptr;
+ }
+
+ // returns true if the branches can safely be executed concurrently:
+ // 1. no two branches write to overlapping memory ranges (the join node's buffer is exempt)
+ // 2. every src of a branch node is either in the same branch or outside the nodes covered by this ggml_cuda_concurrent_event
+ // we assume all nodes live in the same buffer
+ bool is_valid() const {
+ std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
+ write_ranges.resize(n_streams);
+
+ // get join_node's memory range to exclude from overlap checking.
+ // multiple nodes can use join_node's buffer; we synchronize on the join node.
+ const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
+ const int64_t join_start = (int64_t) join_t->data;
+ const int64_t join_end = join_start + ggml_nbytes(join_t);
+
+ for (const auto & [tensor, stream] : stream_mapping) {
+ const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+ const int64_t t_start = (int64_t) t->data;
+ const int64_t t_end = t_start + ggml_nbytes(t);
+
+ // skip tensors that overlap with join_node's buffer.
+ if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+ continue;
+ }
+
+ // concurrent streams begin from 1
+ write_ranges[stream - 1].emplace_back(t_start, t_end);
+ }
+
+ for (int i = 0; i < n_streams; ++i) {
+ // sorts first by start then by end of write range
+ std::sort(write_ranges[i].begin(), write_ranges[i].end());
+ }
+
+ bool writes_overlap = false;
+ bool dependent_srcs = false;
+ for (const auto & [tensor, stream] : stream_mapping) {
+ const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+ const int64_t t_start = (int64_t) t->data;
+ const int64_t t_end = t_start + ggml_nbytes(t);
+
+ // skip tensors that overlap with join_node's buffer
+ if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+ continue;
+ }
+
+ // check if this buffer's write data overlaps with another stream's
+ std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
+ for (int i = 0; i < n_streams; ++i) {
+ if (i == stream - 1) {
+ continue;
+ }
+ auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
+
+ if (it != write_ranges[i].end()) {
+ const std::pair<int64_t, int64_t> & other = *it;
+
+ // std::lower_bound returns the first element where other >= data_range (lexicographically).
+ // This guarantees other.first >= data_range.first.
+ // Therefore, overlap occurs iff other.first < data_range.second
+ // (i.e., the other range starts before this range ends).
+ if (other.first < data_range.second) {
+ GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
+ writes_overlap = true;
+ break;
+ }
+ }
+ }
+
+ // check that every src is either in the same branch or not assigned to any branch
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (!tensor->src[i]) {
+ continue;
+ }
+
+ auto it = stream_mapping.find(tensor->src[i]);
+
+ if (it == stream_mapping.end()) {
+ continue;
+ }
+
+ if (it->second != stream) {
+ dependent_srcs = true;
+ break;
+ }
+ }
+
+ if (dependent_srcs || writes_overlap) {
+ break;
+ }
+ }
+
+ return !writes_overlap && !dependent_srcs;
+ }
+
+ ~ggml_cuda_concurrent_event() {
+ if (fork_event != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(fork_event));
+ }
+ for (cudaEvent_t e : join_events) {
+ if (e != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(e));
+ }
+ }
+ }
+};
+
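+// Per-context state shared between graph_optimize and graph_compute: the original
+// (non-interleaved) node order and the fork/join regions keyed by their fork node.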
+struct ggml_cuda_stream_context {
+ std::vector<const ggml_tensor *> original_nodes;
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+ void reset() {
+ original_nodes.clear();
+ concurrent_events.clear();
+ }
+};
+
struct ggml_backend_cuda_context {
int device;
std::string name;
std::unique_ptr<ggml_cuda_graph> cuda_graph;
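+ // stream on which the current operations are launched; 0 is the main stream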
+ int curr_stream_no = 0;
+
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
+ ggml_cuda_stream_context concurrent_stream_context;
+
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
return streams[device][stream];
}
- cudaStream_t stream() {
- return stream(device, 0);
- }
+ cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+ ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
}
// pool
- std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
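+ // one pool per (device, stream) so that concurrently executing streams do not reuse each other's pool allocations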
+ std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
- static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+ static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
ggml_cuda_pool & pool(int device) {
- if (pools[device] == nullptr) {
- pools[device] = new_pool_for_device(device);
+ if (pools[device][curr_stream_no] == nullptr) {
+ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
}
- return *pools[device];
+ return *pools[device][curr_stream_no];
}
ggml_cuda_pool & pool() {
};
#endif // defined(GGML_USE_VMM)
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
+ [[maybe_unused]] int stream_no) {
#if defined(GGML_USE_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
+ ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
+ bool is_concurrent_event_active = false;
+ ggml_cuda_concurrent_event * concurrent_event = nullptr;
+ bool should_launch_concurrent_events = false;
+
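+ // if `node` is the fork node of a recorded concurrent event, record the fork event on the
+ // main stream and make every branch stream wait on it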
+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
+ auto it = stream_ctx.concurrent_events.find(node);
+ if (it != stream_ctx.concurrent_events.end()) {
+ concurrent_event = &it->second;
+
+ is_concurrent_event_active = true;
+
+ GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+
+ cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
+ GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+ CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+ cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+ CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+ }
+ }
+ };
+
while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
-
[[maybe_unused]] int prev_i = 0;
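+ // concurrent streams are only used if every fork/join region recorded by graph_optimize
+ // passes the checks in is_valid() (no overlapping writes across branches, no cross-branch srcs)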
+ if (!stream_ctx.concurrent_events.empty()) {
+ should_launch_concurrent_events = true;
+ for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
+ should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
+ }
+ }
+ if (should_launch_concurrent_events) {
+ // restore the original node order so that fusion can still be applied within each branch
+ cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
+ cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+ }
+
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
+ if (is_concurrent_event_active) {
+ GGML_ASSERT(concurrent_event);
+
+ if (node == concurrent_event->join_node) {
+ cuda_ctx->curr_stream_no = 0;
+ // record a join event on each forked stream and make the main stream wait on it
+ for (int stream_no = 1; stream_no <= concurrent_event->n_streams; ++stream_no) {
+ CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[stream_no - 1],
+ cuda_ctx->stream(cuda_ctx->device, stream_no)));
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[stream_no - 1]));
+ }
+
+ is_concurrent_event_active = false;
+ concurrent_event = nullptr;
+ } else {
+ GGML_ASSERT(concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+ }
+ } else if (i - prev_i > 1) {
+ // the previous node was fused
+ const ggml_tensor * prev_node = cgraph->nodes[i - 1];
+ try_launch_concurrent_event(prev_node);
+
+ if (is_concurrent_event_active) {
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+ }
+ }
+ // number of nodes fused into the previous launch; must be computed before prev_i is updated
+ [[maybe_unused]] const int nodes_fused = i - prev_i - 1;
+ prev_i = i;
+
#ifdef GGML_CUDA_DEBUG
- const int nodes_fused = i - prev_i - 1;
- prev_i = i;
if (nodes_fused > 0) {
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
}
continue;
}
+
+ // start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
}
#else
GGML_UNUSED(integrated);
-#endif // NDEBUG
+#endif // NDEBUG
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
+
+ if (!is_concurrent_event_active) {
+ try_launch_concurrent_event(node);
+ }
}
}
}
}
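+// Reorders the graph so that independent branches between a fork node and its join node
+// (e.g. the Q/K/V projections that all consume the attention norm) can be dispatched on
+// separate CUDA streams by graph_compute. Opt-in via the GGML_CUDA_GRAPH_OPT=1 environment variable.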
+static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+ static bool enable_graph_optimization = [] {
+ const char * env = getenv("GGML_CUDA_GRAPH_OPT");
+ return env != nullptr && atoi(env) == 1;
+ }();
+
+ if (!enable_graph_optimization) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on a single GPU in the CUDA backend");
+ GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", (void *) cgraph, cgraph->n_nodes);
+
+ ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
+ stream_context.reset();
+
+ // out-degree (number of consumers) of each node
+ std::unordered_map<const ggml_tensor *, int> fan_out;
+ // reverse mapping of node to index in the cgraph
+ std::unordered_map<const ggml_tensor *, int> node_indices;
+
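+ // ops that launch no device work; they are skipped when counting fan-out and can be placed in any branch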
+ const auto & is_noop = [](const ggml_tensor * node) -> bool {
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
+ node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+ };
+
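+ // dst depends on src if it consumes src directly or if both are views of the same underlying tensor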
+ const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+ if (dst->src[s] == src) {
+ return true;
+ }
+ }
+ // implicit dependency if they view the same tensor
+ const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
+ const ggml_tensor * src2 = src->view_src ? src->view_src : src;
+ if (dst2 == src2) {
+ return true;
+ }
+ return false;
+ };
+
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
+ const ggml_tensor * node = cgraph->nodes[node_idx];
+ node_indices[node] = node_idx;
+
+ if (is_noop(node)) {
+ continue;
+ }
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+ const ggml_tensor * src = node->src[src_idx];
+ // TODO: check why nrows > 1 fails
+ // node is already non-null and not a noop here; only count srcs that actually exist
+ if (src != nullptr && ggml_nrows(node) <= 1) {
+ fan_out[src] += 1;
+ }
+ }
+ }
+
+ // Target Q, K, V for concurrency
+ // this is a more general way to find candidate nodes for concurrency (although it has not been tested for anything else):
+ // 1. find fan-out (fork) nodes where the same input is used at least N times (for QKV this is "attn-norm")
+ // 2. find the join node, where 2 or more of the branch outputs are required (for QKV this would be "KQ" or "flash-attn")
+ // 3. account for all branches from the fork to the join
+ // 4. interleave the branches to extend the lifetimes of their tensors (see below for more details)
+ // 5. save the original cgraph and restore it in graph_compute to enable fusion within streams
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
+
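+ // currently only forks with exactly 3 consumers are targeted, i.e. the Q/K/V projections sharing the attention norm as input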
+ const int min_fan_out = 3;
+ const int max_fan_out = 3;
+
+ // store {fork_idx, join_idx}
+ std::vector<std::pair<int, int>> concurrent_node_ranges;
+
+ // save the original nodes
+ std::vector<const ggml_tensor *> original_nodes;
+ original_nodes.reserve(cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ original_nodes.push_back(cgraph->nodes[i]);
+ }
+ cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
+
+ for (const auto & [root_node, count] : fan_out) {
+ if (count >= min_fan_out && count <= max_fan_out) {
+ const int root_node_idx = node_indices[root_node];
+
+ bool is_part_of_event = false;
+ for (const auto & [start, end] : concurrent_node_ranges) {
+ if (root_node_idx >= start && root_node_idx <= end) {
+ is_part_of_event = true;
+ }
+ }
+
+ if (is_part_of_event) {
+ continue;
+ }
+
+ std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+ const ggml_tensor * node = cgraph->nodes[i];
+ if (!is_noop(node) && depends_on(node, root_node)) {
+ nodes_per_branch.push_back({ node });
+ }
+ }
+
+ GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
+
+ // find the join point
+ const ggml_tensor * join_node = nullptr;
+
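+ // a node belongs to a branch if it depends on any node already assigned to that branch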
+ const auto & belongs_to_branch = [&](const ggml_tensor * node,
+ const std::vector<const ggml_tensor *> & branch) -> bool {
+ for (const ggml_tensor * n : branch) {
+ if (depends_on(node, n)) {
+ return true;
+ }
+ }
+ return false;
+ };
+
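+ // walk forward from the fork: the first node that depends on two or more branches is the join
+ // node; any other node is appended to the branch it belongs to (noop nodes can go anywhere)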
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+ const ggml_tensor * curr_node = cgraph->nodes[i];
+
+ int num_joins = 0;
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
+ num_joins++;
+ }
+ }
+
+ if (num_joins >= 2) {
+ join_node = curr_node;
+ break;
+ }
+
+ bool found_branch = false;
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
+ if (belongs_to_branch(curr_node, branch_vec)) {
+ // continue accumulating
+ if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
+ branch_vec.push_back(curr_node);
+ }
+ found_branch = true;
+ }
+ }
+
+ if (!found_branch && is_noop(curr_node)) {
+ // we can put it in any branch because it will be ignored
+ nodes_per_branch[0].push_back(curr_node);
+ }
+ }
+
+ if (join_node) {
+ // create the ggml_cuda_concurrent_event for this fork/join region
+ ggml_cuda_concurrent_event concurrent_event((int) nodes_per_branch.size());
+ concurrent_event.join_node = join_node;
+
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
+ concurrent_event.stream_mapping[n] = branch_idx + 1;
+ }
+ }
+
+ int fork_node_idx = node_indices[root_node];
+ int join_node_idx = node_indices[join_node];
+
+ int current_branch_idx = 0;
+ int current_node_idx = fork_node_idx + 1;
+ const int n_branches = nodes_per_branch.size();
+
+ int total_branch_nodes = 0;
+ for (const std::vector<const ggml_tensor *> & branch_nodes : nodes_per_branch) {
+ total_branch_nodes += branch_nodes.size();
+ }
+
+ // if there are other nodes between the fork and the join that are not accounted for
+ // (usually cpy nodes), ignore this fork
+ if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
+ GGML_LOG_DEBUG(
+ "Skipping %s because the number of nodes in the middle is not equal to the total number of "
+ "branch nodes %d != %d\n",
+ root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
+ continue;
+ }
+
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
+ GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
+ concurrent_events.emplace(root_node, std::move(concurrent_event));
+ GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
+ concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
+
+ // interleave the branch nodes to extend tensor lifetimes so that the ggml graph allocator doesn't recycle their memory
+ // example transformation:
+ // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
+ // [attn-norm, QMul, KMul, VMul, QNorm, KNorm, QRope, KRope, attn]
+ while (current_node_idx < join_node_idx) {
+ std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
+
+ // at least one branch must still have nodes left, otherwise the loop would never terminate
+ bool has_node = false;
+ for (const std::vector<const ggml_tensor *> & branch : nodes_per_branch) {
+ has_node |= !branch.empty();
+ }
+
+ GGML_ASSERT(has_node);
+
+ if (branch_nodes.empty()) {
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
+ continue;
+ }
+
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+ current_node_idx++;
+ branch_nodes.erase(branch_nodes.begin());
+
+ // also emit any noop nodes that follow in this branch
+ while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+ current_node_idx++;
+ branch_nodes.erase(branch_nodes.begin());
+ }
+
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
+ }
+ }
+ }
+ }
+}
+
static const ggml_backend_i ggml_backend_cuda_interface = {
/* .get_name = */ ggml_backend_cuda_get_name,
/* .free = */ ggml_backend_cuda_free,
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
- /* .graph_optimize = */ NULL,
+ /* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
};
static ggml_guid_t ggml_backend_cuda_guid() {