From: Georgi Gerganov Date: Sat, 13 Sep 2025 10:54:28 +0000 (+0300) Subject: metal : allow ops to run concurrently (llama/15929) X-Git-Tag: v0.9.1~35 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0cf545941e739639e9b21d5ba6410daa67baa203;p=pkg%2Fggml%2Fsources%2Fggml metal : allow ops to run concurrently (llama/15929) * metal : run graphs ops concurrently ggml-ci * cont : add flags for debugging and disabling concurrency ggml-ci * cont : refactor and handle fusing ggml-ci * cont : simplify - no need to use GPU address ggml-ci * cont : prepare mem ranges for reuse + add ggml-metal-common.cpp ggml-ci * cont : avoid redundant keywords in cpp [no ci] * metal : reorder graph for better concurrency ggml-ci * metal : fix race on mem pool buffers ggml-ci * cont : add env GGML_METAL_GRAPH_OPTIMIZE_DISABLE ggml-ci * cont : refactor, optimize, add comments ggml-ci * cont : refactor ggml-metal.m ggml-ci * minor : update logs [no ci] --- diff --git a/src/ggml-metal/CMakeLists.txt b/src/ggml-metal/CMakeLists.txt index 0ca8a3c5..65c131b6 100644 --- a/src/ggml-metal/CMakeLists.txt +++ b/src/ggml-metal/CMakeLists.txt @@ -6,6 +6,7 @@ message(STATUS "Metal framework found") ggml_add_backend_library(ggml-metal ggml-metal.m + ggml-metal-common.cpp ) target_link_libraries(ggml-metal PRIVATE diff --git a/src/ggml-metal/ggml-metal-common.cpp b/src/ggml-metal/ggml-metal-common.cpp new file mode 100644 index 00000000..6a869ff2 --- /dev/null +++ b/src/ggml-metal/ggml-metal-common.cpp @@ -0,0 +1,445 @@ +#include "ggml-metal-common.h" + +#include "ggml-impl.h" + +#include + +struct ggml_mem_range { + uint64_t pb; // buffer id + + uint64_t p0; // begin + uint64_t p1; // end + + ggml_mem_range_type pt; +}; + +struct ggml_mem_ranges { + std::vector ranges; + + int debug = 0; +}; + +struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) { + auto * res = new ggml_mem_ranges; + + res->ranges.reserve(256); + res->debug = debug; + + return res; +} + +void ggml_mem_ranges_free(ggml_mem_ranges * mrs) { + delete mrs; +} + +void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) { + mrs->ranges.clear(); +} + +static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) { + mrs->ranges.push_back(mrp); + + return true; +} + +static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) { + // always use the base tensor + tensor = tensor->view_src ? 
tensor->view_src : tensor; + + GGML_ASSERT(!tensor->view_src); + + ggml_mem_range mrp; + + if (tensor->buffer) { + // when the tensor is allocated, use the actual memory address range of the buffer + mrp = { + /*.pb =*/ (uint64_t) tensor->buffer, + /*.p0 =*/ (uint64_t) tensor->data, + /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor), + /*.pt =*/ pt, + }; + } else { + // otherwise, the tensor ptr is used as an unique id of the memory ranges + // that the tensor will be using when it is allocated + mrp = { + /*.pb =*/ (uint64_t) tensor, + /*.p0 =*/ 0, // + /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used + /*.pt =*/ pt, + }; + }; + + return mrp; +} + +static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC); +} + +static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST); +} + +static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); + } + + return ggml_mem_ranges_add(mrs, mrp); +} + +static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); + } + + return ggml_mem_ranges_add(mrs, mrp); +} + +bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->src[i]) { + ggml_mem_ranges_add_src(mrs, tensor->src[i]); + } + } + + return ggml_mem_ranges_add_dst(mrs, tensor); +} + +static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) { + for (size_t i = 0; i < mrs->ranges.size(); i++) { + const auto & cmp = mrs->ranges[i]; + + if (mrp.pb != cmp.pb) { + continue; + } + + if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { + continue; + } + + if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) { + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", + __func__, + mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + mrp.pb, mrp.p0, mrp.p1, + cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + cmp.pb, cmp.p0, cmp.p1); + } + + return false; + } + } + + return true; +} + +static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mrp); + + return res; +} + +static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mrp); + + return res; +} + +bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { + return false; + } + } + } + + return ggml_mem_ranges_check_dst(mrs, tensor); +} + +// TODO: move to ggml.h? 
+static bool is_empty(ggml_op op) { + switch (op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_TRANSPOSE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + return true; + default: + return false; + } +} + +struct node_info { + ggml_tensor * node; + + std::vector fused; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + bool is_empty() const { + return ::is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } +}; + +static std::vector ggml_metal_graph_optimize_reorder(const std::vector & nodes) { + // helper to add node src and dst ranges + const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) { + return false; + } + } + } + + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_add_dst(mrs, node.dst()); + }; + + // helper to check if a node can run concurrently with the existing set of nodes + const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) { + return false; + } + } + } + + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_check_dst(mrs, node.dst()); + }; + + // perform reorders only across these types of ops + // can be expanded when needed + // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS + // the dependencies from such ops are not always represented in the graph + const auto & h_safe = [](ggml_op op) { + switch (op) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_GLU: + case GGML_OP_SCALE: + case GGML_OP_GET_ROWS: + return true; + default: + return is_empty(op); + } + }; + + const int n = nodes.size(); + + std::vector res; + res.reserve(n); + + std::vector used(n, false); + + ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0); + ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0); + + for (int i0 = 0; i0 < n; i0++) { + if (used[i0]) { + continue; + } + + const auto & node0 = nodes[i0]; + + // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0) + // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0 + // + // note: we can always add empty nodes to the concurrent set as they don't read nor write anything + if (!node0.is_empty() && !h_check(mrs0, node0)) { + // this will hold the set of memory ranges from the nodes that haven't been processed yet + // if a node is not concurrent with this set, we cannot reorder it + ggml_mem_ranges_reset(mrs1); + + // initialize it with the current node + h_add(mrs1, node0); + + // that many nodes forward to search for a concurrent node + constexpr int N_FORWARD = 8; + + for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { + if (used[i1]) { + continue; + } + + const 
auto & node1 = nodes[i1]; + + // disallow reordering of certain ops + if (!h_safe(node1.op())) { + break; + } + + const bool is_empty = node1.is_empty(); + + // to add a concurrent node, it has to be: + // + empty or concurrent with all nodes in the existing concurrent set (mrs0) + // + concurrent with all nodes prior to it that haven't been processed yet (mrs1) + if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) { + // add the node to the existing concurrent set (i.e. reorder it for early execution) + h_add(mrs0, node1); + res.push_back(i1); + + // mark as used, so we skip re-processing it later + used[i1] = true; + } else { + // expand the set of nodes that haven't been processed yet + h_add(mrs1, node1); + } + } + + // finalize the concurrent set and begin a new one + ggml_mem_ranges_reset(mrs0); + } + + // expand the concurrent set with the current node + { + h_add(mrs0, node0); + res.push_back(i0); + } + } + + ggml_mem_ranges_free(mrs0); + ggml_mem_ranges_free(mrs1); + + return res; +} + +void ggml_metal_graph_optimize(ggml_cgraph * gf) { + constexpr int MAX_FUSE = 16; + + const int n = gf->n_nodes; + + enum ggml_op ops[MAX_FUSE]; + + std::vector nodes; + nodes.reserve(gf->n_nodes); + + // fuse nodes: + // we don't want to make reorders that break fusing, so we first pack all fusable tensors + // and perform the reorder over the fused nodes. after the reorder is done, we unfuse + for (int i = 0; i < n; i++) { + node_info node = { + /*.node =*/ gf->nodes[i], + /*.fused =*/ {}, + }; + + // fuse only ops that start with these operations + // can be expanded when needed + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_RMS_NORM) { + ops[0] = node.op(); + + int f = i + 1; + while (f < n && f < i + MAX_FUSE) { + // conservatively allow fusing only these ops + // can be expanded when needed + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } + ops[f - i] = gf->nodes[f]->op; + f++; + } + + f -= i; + for (; f > 1; f--) { + if (ggml_can_fuse(gf, i, ops, f)) { + break; + } + } + + // add the fused tensors into the node info so we can unfuse them later + for (int k = 1; k < f; k++) { + ++i; + + // the .dst() becomes the last fused tensor + node.add_fused(gf->nodes[i]); + } + } + + nodes.push_back(std::move(node)); + } + + // reorder to improve concurrency +#if 1 + const auto order = ggml_metal_graph_optimize_reorder(nodes); +#else + std::vector order(nodes.size()); + for (size_t i = 0; i < nodes.size(); i++) { + order[i] = i; + } +#endif + + // unfuse + { + int j = 0; + for (const auto i : order) { + const auto & node = nodes[i]; + + gf->nodes[j++] = node.node; + + for (auto * fused : node.fused) { + gf->nodes[j++] = fused; + } + } + } +} diff --git a/src/ggml-metal/ggml-metal-common.h b/src/ggml-metal/ggml-metal-common.h new file mode 100644 index 00000000..c1402895 --- /dev/null +++ b/src/ggml-metal/ggml-metal-common.h @@ -0,0 +1,52 @@ +// helper functions for ggml-metal that are too difficult to implement in Objective-C + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_tensor; +struct ggml_cgraph; + +enum ggml_mem_range_type { + MEM_RANGE_TYPE_SRC = 0, + MEM_RANGE_TYPE_DST = 1, +}; + +// a helper object that can be used for reordering operations to improve concurrency +// +// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they +// don't write to a memory that is being read by another task or written 
to by another task in the set +// +// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task +// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the +// tasks already in the set) +// +struct ggml_mem_ranges; + +struct ggml_mem_ranges * ggml_mem_ranges_init(int debug); +void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs); + +// remove all ranges from the set +void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs); + +// add src or dst ranges to track +bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor); + +// return false if: +// - new src range overlaps with any existing dst range +// - new dst range overlaps with any existing range (src or dst) +bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor); + +// reorder the nodes in the graph to improve concurrency, while respecting fusion +// +// note: this implementation is generic and not specific to metal +// if it proves to work well, we can start using it for other backends in the future +void ggml_metal_graph_optimize(struct ggml_cgraph * gf); + +#ifdef __cplusplus +} +#endif diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m index 6a42b3d7..b8e06aa6 100644 --- a/src/ggml-metal/ggml-metal.m +++ b/src/ggml-metal/ggml-metal.m @@ -3,6 +3,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" #import "ggml-metal-impl.h" +#import "ggml-metal-common.h" #import @@ -61,8 +62,11 @@ static struct ggml_backend_metal_device_context { bool has_bfloat; bool use_bfloat; bool use_fusion; + bool use_concurrency; bool use_shared_buffers; + bool use_graph_optimize; + int debug_graph; int debug_fusion; // how many times a given op was fused @@ -83,7 +87,10 @@ static struct ggml_backend_metal_device_context { /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, /*.use_fusion =*/ true, + /*.use_concurrency =*/ true, /*.use_shared_buffers =*/ true, + /*.use_graph_optimize =*/ true, + /*.debug_graph =*/ 0, /*.debug_fusion =*/ 0, /*.fuse_cnt =*/ { 0 }, /*.max_size =*/ 0, @@ -124,7 +131,14 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev #else ctx->use_bfloat = false; #endif - ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + + ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; + + { + const char * val = getenv("GGML_METAL_GRAPH_DEBUG"); + ctx->debug_graph = val ? 
atoi(val) : 0; + } { const char * val = getenv("GGML_METAL_FUSION_DEBUG"); @@ -137,6 +151,12 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->use_shared_buffers = false; } + ctx->use_graph_optimize = true; + + if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) { + ctx->use_graph_optimize = false; + } + memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); ctx->max_size = ctx->mtl_device.maxBufferLength; @@ -628,7 +648,7 @@ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { @end // -// ggml_metal_mem_pool +// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE] // struct ggml_metal_mem_pool { @@ -791,6 +811,9 @@ struct ggml_metal_command_buffer { // each command buffer has a memory pool from which it can allocate temporary buffers during the compute struct ggml_metal_mem_pool * mem_pool; + + // used to enable concurrent execution of ops in the command buffers + struct ggml_mem_ranges * mem_ranges; }; struct ggml_backend_metal_context { @@ -1091,7 +1114,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false"); + GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, ctx_dev->use_concurrency ? "true" : "false"); GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false"); + GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, ctx_dev->use_graph_optimize ? "true" : "false"); GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); ctx->capture_next_compute = false; @@ -1105,6 +1130,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); ctx->cmd_bufs[i].mem_pool->device = device; + + if (ctx_dev->use_concurrency) { + ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph); + } } ctx->cmd_bufs_ext = [[NSMutableArray alloc] init]; @@ -1715,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { } ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); + + if (ctx->cmd_bufs[i].mem_ranges) { + ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges); + } } [ctx->cmd_bufs_ext removeAllObjects]; @@ -2071,12 +2104,51 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex } } -static int ggml_metal_encode_node( - ggml_backend_t backend, - int idx, - int idx_end, - id encoder, - struct ggml_metal_mem_pool * mem_pool) { +struct ggml_metal_encode_context { + ggml_backend_t backend; + + id encoder; + + struct ggml_metal_mem_pool * mem_pool; + + struct ggml_mem_ranges * mem_ranges; +}; + +static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) { + if (!ctx->mem_ranges) { + return true; + } + + [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; + + ggml_mem_ranges_reset(ctx->mem_ranges); + + return true; +} + +static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) { + if (!ctx->mem_ranges) { + return false; + } + + return ggml_mem_ranges_check(ctx->mem_ranges, node); +} + +static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) { + if (!ctx->mem_ranges) { + 
return true; + } + + return ggml_mem_ranges_add(ctx->mem_ranges, node); +} + +static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) { + ggml_backend_t backend = ctx_enc->backend; + + id encoder = ctx_enc->encoder; + + struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool; + struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; @@ -2159,38 +2231,71 @@ static int ggml_metal_encode_node( const uint64_t nb2 = dst ? dst->nb[2] : 0; const uint64_t nb3 = dst ? dst->nb[3] : 0; + size_t offs_src[GGML_MAX_SRC]; + + id id_src[GGML_MAX_SRC]; + + enum ggml_type srct[GGML_MAX_SRC]; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + offs_src[i] = 0; + id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil; + srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT; + } + + // TODO: tmp shorthands - remove + size_t offs_src0 = offs_src[0]; + size_t offs_src1 = offs_src[1]; + size_t offs_src2 = offs_src[2]; + + id id_src0 = id_src[0]; + id id_src1 = id_src[1]; + id id_src2 = id_src[2]; + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; + size_t offs_dst = 0; - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; int n_fuse = 1; -#if 0 - GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - if (src0) { - GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, - ggml_is_contiguous(src0), src0->name); - } - if (src1) { - GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, - ggml_is_contiguous(src1), src1->name); - } - if (dst) { - GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, - dst->name); + // check if the current node can run concurrently with other nodes before it + // the condition is that: + // - the current node cannot write to any previous src or dst ranges + // - the current node cannot read from any previous dst ranges + // + // if the condition is not satisfied, we put a memory barrier and clear all ranges + // otherwise, we add the new ranges to the encoding context and process the node concurrently + // + { + const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node); + + if (!is_concurrent) { + ggml_metal_encode_concurrency_reset(ctx_enc); + } + + if (ctx_dev->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? 
"(concurrent)" : ""); + } + if (ctx_dev->debug_graph > 1) { + if (src0) { + GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } + } } -#endif id device = ctx_dev->mtl_device; @@ -2389,6 +2494,14 @@ static int ggml_metal_encode_node( if (n_fuse > 1) { id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) { + ggml_metal_encode_concurrency_reset(ctx_enc); + + break; + } + } } [encoder setComputePipelineState:pipeline]; @@ -2533,6 +2646,8 @@ static int ggml_metal_encode_node( const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encode_concurrency_reset(ctx_enc); } const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; @@ -3997,6 +4112,12 @@ static int ggml_metal_encode_node( default: break; } + // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool + // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer. + // so we add this extra barrier to prevent the race. 
+ // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE] + ggml_metal_encode_concurrency_reset(ctx_enc); + // tokens per expert const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); @@ -4057,6 +4178,9 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)]; } + // this barrier is always needed because the next kernel has to wait for the id maps to be computed + ggml_metal_encode_concurrency_reset(ctx_enc); + { id pipeline = nil; @@ -4525,6 +4649,14 @@ static int ggml_metal_encode_node( if (n_fuse > 1) { id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) { + ggml_metal_encode_concurrency_reset(ctx_enc); + + break; + } + } } id pipeline; @@ -4668,7 +4800,6 @@ static int ggml_metal_encode_node( } break; case GGML_OP_ROPE: { - // make sure we have one or more position id(ne10) per token(ne02) GGML_ASSERT(ne10 % ne02 == 0); GGML_ASSERT(ne10 >= ne02); @@ -5427,6 +5558,10 @@ static int ggml_metal_encode_node( GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); + // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE] + // still, we assume that concurrent FA won't happen before we do the refactor + //ggml_metal_encode_concurrency_reset(ctx_enc); + const int32_t nrows = ne1*ne2*ne3; // temp buffer for writing the results from each workgroup @@ -5447,6 +5582,8 @@ static int ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + ggml_metal_encode_concurrency_reset(ctx_enc); + // reduce the results from the workgroups { ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { @@ -5677,7 +5814,7 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; - case GGML_OP_ARGMAX: + case GGML_OP_ARGMAX: { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous_1(src0)); @@ -5709,6 +5846,19 @@ static int ggml_metal_encode_node( } } + if (ctx_dev->debug_graph > 0) { + if (n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse); + } + } + + // update the mem ranges in the encoding context + for (int i = 0; i < n_fuse; ++i) { + if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) { + ggml_metal_encode_concurrency_reset(ctx_enc); + } + } + return n_fuse; } @@ -5719,7 +5869,7 @@ static enum ggml_status ggml_metal_graph_compute( struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; // number of nodes encoded by the main thread (empirically determined) - const int n_main = 128; + const int n_main = 64; // number of threads in addition to the main thread const int n_cb = ctx->n_cb; @@ -5774,6 +5924,7 @@ static enum ggml_status ggml_metal_graph_compute( // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009 + // [TAG_MEM_POOL_REMOVE] //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; id cmd_buf = [ctx->queue 
commandBuffer]; [cmd_buf retain]; @@ -6547,6 +6698,18 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, return ggml_metal_graph_compute(backend, cgraph); } +static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + //const int64_t t_start = ggml_time_us(); + + if (ctx_dev->use_graph_optimize) { + ggml_metal_graph_optimize(cgraph); + } + + //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); +} + static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { GGML_ASSERT(ggml_backend_is_metal(backend)); @@ -6573,12 +6736,25 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; - id cmd_buf = ctx->cmd_bufs[cb_idx].obj; - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges; ggml_metal_mem_pool_reset(mem_pool); - id encoder = [cmd_buf computeCommandEncoder]; + if (mem_ranges) { + ggml_mem_ranges_reset(mem_ranges); + } + + id encoder; + + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + if (ctx_dev->use_concurrency) { + encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } else { + encoder = [cmd_buf computeCommandEncoder]; + } int node_start = 0; int node_end = n_nodes_0; @@ -6590,12 +6766,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const bool should_capture = ctx->capture_next_compute; + struct ggml_metal_encode_context ctx_enc = { + /*.backend =*/ backend, + /*.encoder =*/ encoder, + /*.mem_pool =*/ mem_pool, + /*.mem_ranges =*/ mem_ranges, + }; + for (int idx = node_start; idx < node_end;) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool); + const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end); if (idx + res > node_end) { GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", "https://github.com/ggml-org/llama.cpp/pull/14849"); @@ -6638,7 +6821,7 @@ static struct ggml_backend_i ggml_backend_metal_i = { // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .optimize_graph = */ NULL, + /* .optimize_graph = */ ggml_backend_metal_graph_optimize, }; static ggml_guid_t ggml_backend_metal_guid(void) {
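
For reference, the rule implemented by ggml_mem_ranges_check above can be summarized in a small standalone sketch. This is a simplified illustration only, not part of the patch: the names range, range_type and conflicts, and the example values, are hypothetical, and the real code additionally derives per-tensor src/dst ranges via ggml_mem_ranges_add. Two ops may run concurrently unless one of them writes a byte range in a buffer that the other reads or writes; the interval comparison below mirrors the one used in the patch.

#include <cstdint>
#include <cstdio>
#include <vector>

enum range_type { SRC, DST };

struct range {
    uint64_t   buf;  // buffer id (or the tensor pointer when not yet allocated)
    uint64_t   p0;   // begin
    uint64_t   p1;   // end
    range_type type; // read (SRC) or write (DST)
};

// a new range conflicts with the tracked set iff it overlaps an existing range
// in the same buffer and at least one of the two ranges is a write
static bool conflicts(const std::vector<range> & set, const range & r) {
    for (const auto & cmp : set) {
        if (r.buf != cmp.buf) {
            continue; // ranges in different buffers never conflict
        }
        if (r.type == SRC && cmp.type == SRC) {
            continue; // concurrent reads are always safe
        }
        if (r.p0 < cmp.p1 && r.p1 >= cmp.p0) {
            return true; // overlap -> a barrier is needed before this op
        }
    }
    return false;
}

int main() {
    // ranges produced by a previously encoded op: reads [0, 128), writes [128, 256) of buffer 1
    std::vector<range> set = {
        { 1,   0, 128, SRC },
        { 1, 128, 256, DST },
    };

    // reading what the previous op writes -> must wait (barrier)
    printf("read  [128, 192): %s\n", conflicts(set, { 1, 128, 192, SRC }) ? "barrier" : "concurrent");
    // reading a disjoint region -> can be dispatched concurrently
    printf("read  [256, 512): %s\n", conflicts(set, { 1, 256, 512, SRC }) ? "barrier" : "concurrent");
    // writing into a region another op reads -> must wait (barrier)
    printf("write [  0,  64): %s\n", conflicts(set, { 1,   0,  64, DST }) ? "barrier" : "concurrent");

    return 0;
}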